@@ -11,12 +11,6 @@ using Zygote
1111using MatrixAlgebraKit
1212using MatrixAlgebraKit: LAPACK_HouseholderQR, LAPACK_HouseholderLQ, diagview
1313
14- const _repartition = @static if isdefined (Base, :get_extension )
15- Base. get_extension (TensorKit, :TensorKitChainRulesCoreExt ). _repartition
16- else
17- TensorKit. TensorKitChainRulesCoreExt. _repartition
18- end
19-
2014# Test utility
2115# -------------
2216function ChainRulesTestUtils. rand_tangent (rng:: AbstractRNG , x:: AbstractTensorMap )
4034precision (:: Type{<:Union{Float32, Complex{Float32}}} ) = 1.0e-2
4135precision (:: Type{<:Union{Float64, Complex{Float64}}} ) = 1.0e-5
4236
43- function randindextuple (N:: Int , k:: Int = rand (0 : N))
44- @assert 0 ≤ k ≤ N
45- _p = randperm (N)
46- return (tuple (_p[1 : k]. .. ), tuple (_p[(k + 1 ): end ]. .. ))
47- end
48-
4937function test_ad_rrule (f, args... ; check_inferred = false , kwargs... )
5038 test_rrule (
5139 Zygote. ZygoteRuleConfig (), f, args... ;
133121# Tests
134122# -----
135123
136- ChainRulesTestUtils. test_method_tables ()
137-
138124spacelist = (
139125 (ℂ^ 2 , (ℂ^ 3 )' , ℂ^ 3 , ℂ^ 2 , (ℂ^ 2 )' ),
140126 (
@@ -178,250 +164,14 @@ for V in spacelist
178164 I = sectortype (eltype (V))
179165 Istr = type_repr (I)
180166 eltypes = isreal (sectortype (eltype (V))) ? (Float64, ComplexF64) : (ComplexF64,)
181- symmetricbraiding = BraidingStyle (sectortype (eltype (V))) isa SymmetricBraiding
182167 println (" ---------------------------------------" )
183168 println (" Auto-diff with symmetry: $Istr " )
184169 println (" ---------------------------------------" )
185170 @timedtestset " AD with symmetry $Istr " verbose = true begin
186171 V1, V2, V3, V4, V5 = V
187172 W = V1 ⊗ V2
188- @timedtestset " Basic utility" begin
189- T1 = randn (Float64, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ])
190- T2 = randn (ComplexF64, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ])
191-
192- P1 = ProjectTo (T1)
193- @test P1 (T1) == T1
194- @test P1 (T2) == real (T2)
195-
196- test_rrule (copy, T1)
197- test_rrule (copy, T2)
198- test_rrule (TensorKit. copy_oftype, T1, ComplexF64)
199- if symmetricbraiding
200- test_rrule (convert, Array, T1)
201- test_rrule (
202- TensorMap, convert (Array, T1), codomain (T1), domain (T1);
203- fkwargs = (; tol = Inf )
204- )
205- end
206-
207- test_rrule (Base. getproperty, T1, :data )
208- test_rrule (TensorMap{scalartype (T1)}, T1. data, T1. space)
209- test_rrule (Base. getproperty, T2, :data )
210- test_rrule (TensorMap{scalartype (T2)}, T2. data, T2. space)
211- end
212-
213- @timedtestset " Basic utility (DiagonalTensor)" begin
214- for v in V
215- rdim = reduceddim (v)
216- D1 = DiagonalTensorMap (randn (rdim), v)
217- D2 = DiagonalTensorMap (randn (rdim), v)
218- D = D1 + im * D2
219- T1 = TensorMap (D1)
220- T2 = TensorMap (D2)
221- T = T1 + im * T2
222-
223- # real -> real
224- P1 = ProjectTo (D1)
225- @test P1 (D1) == D1
226- @test P1 (T1) == D1
227-
228- # complex -> complex
229- P2 = ProjectTo (D)
230- @test P2 (D) == D
231- @test P2 (T) == D
232-
233- # real -> complex
234- @test P2 (D1) == D1 + 0 * im * D1
235- @test P2 (T1) == D1 + 0 * im * D1
236-
237- # complex -> real
238- @test P1 (D) == D1
239- @test P1 (T) == D1
240-
241- test_rrule (DiagonalTensorMap, D1. data, D1. domain)
242- test_rrule (DiagonalTensorMap, D. data, D. domain)
243- test_rrule (Base. getproperty, D, :data )
244- test_rrule (Base. getproperty, D1, :data )
245-
246- test_rrule (DiagonalTensorMap, rand! (T1))
247- test_rrule (DiagonalTensorMap, randn! (T))
248- end
249- end
250-
251- @timedtestset " Basic Linear Algebra with scalartype $T " for T in eltypes
252- A = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ] ⊗ V[5 ])
253- B = randn (T, space (A))
254-
255- test_rrule (real, A)
256- test_rrule (imag, A)
257-
258- test_rrule (+ , A, B)
259- test_rrule (- , A)
260- test_rrule (- , A, B)
261-
262- α = randn (T)
263- test_rrule (* , α, A)
264- test_rrule (* , A, α)
265-
266- C = randn (T, domain (A), codomain (A))
267- test_rrule (* , A, C)
268-
269- test_rrule (transpose, A, ((2 , 5 , 4 ), (1 , 3 )))
270- symmetricbraiding && test_rrule (permute, A, ((1 , 3 , 2 ), (5 , 4 )))
271- test_rrule (twist, A, 1 )
272- test_rrule (twist, A, [1 , 3 ])
273-
274- test_rrule (flip, A, 1 )
275- test_rrule (flip, A, [1 , 3 , 4 ])
276-
277- D = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ])
278- E = randn (T, V[4 ] ← V[5 ])
279- symmetricbraiding && test_rrule (⊗ , D, E)
280- end
281-
282- @timedtestset " Linear Algebra part II with scalartype $T " for T in eltypes
283- atol = precision (T)
284- rtol = precision (T)
285- for i in 1 : 3
286- E = randn (T, ⊗ (V[1 : i]. .. ) ← ⊗ (V[1 : i]. .. ))
287- test_rrule (LinearAlgebra. tr, E; atol, rtol)
288- test_rrule (exp, E; check_inferred = false , atol, rtol)
289- test_rrule (inv, E; atol, rtol)
290- end
291-
292- A = randn (T, V[1 ] ⊗ V[2 ] ← V[3 ] ⊗ V[4 ] ⊗ V[5 ])
293- test_rrule (LinearAlgebra. adjoint, A; atol, rtol)
294- test_rrule (LinearAlgebra. norm, A, 2 ; atol, rtol)
295-
296- B = randn (T, space (A))
297- test_rrule (LinearAlgebra. dot, A, B; atol, rtol)
298- end
299-
300- @timedtestset " Matrix functions ($T )" for T in eltypes
301- atol = precision (T)
302- rtol = precision (T)
303- for f in (sqrt, exp)
304- check_inferred = false # !(T <: Real) # not type-stable for real functions
305- t1 = randn (T, V[1 ] ← V[1 ])
306- t2 = randn (T, V[2 ] ← V[2 ])
307- d = DiagonalTensorMap {T} (undef, V[1 ])
308- d2 = DiagonalTensorMap {T} (undef, V[1 ])
309- d3 = DiagonalTensorMap {T} (undef, V[1 ])
310- if (T <: Real && f === sqrt)
311- # ensuring no square root of negative numbers
312- randexp! (d. data)
313- d. data .+ = 5
314- randexp! (d2. data)
315- d2. data .+ = 5
316- randexp! (d3. data)
317- d3. data .+ = 5
318- else
319- randn! (d. data)
320- randn! (d2. data)
321- randn! (d3. data)
322- end
323-
324- test_rrule (f, t1; rrule_f = Zygote. rrule_via_ad, check_inferred, atol, rtol)
325- test_rrule (f, t2; rrule_f = Zygote. rrule_via_ad, check_inferred, atol, rtol)
326- test_rrule (f, d ⊢ d2; check_inferred, output_tangent = d3, atol, rtol)
327- end
328- end
329-
330- symmetricbraiding &&
331- @timedtestset " TensorOperations with scalartype $T " for T in eltypes
332- atol = precision (T)
333- rtol = precision (T)
334-
335- @timedtestset " tensortrace!" begin
336- for _ in 1 : 5
337- k1 = rand (0 : 2 )
338- k2 = rand (1 : 2 )
339- V1 = map (v -> rand (Bool) ? v' : v, rand (V, k1))
340- V2 = map (v -> rand (Bool) ? v' : v, rand (V, k2))
341-
342- (_p, _q) = randindextuple (k1 + 2 * k2, k1)
343- p = _repartition (_p, rand (0 : k1))
344- q = _repartition (_q, k2)
345- ip = _repartition (invperm (linearize ((_p, _q))), rand (0 : (k1 + 2 * k2)))
346- A = randn (T, permute (prod (V1) ⊗ prod (V2) ← prod (V2), ip))
347-
348- α = randn (T)
349- β = randn (T)
350- for conjA in (false , true )
351- C = randn! (TensorOperations. tensoralloc_add (T, A, p, conjA, Val (false )))
352- test_rrule (tensortrace!, C, A, p, q, conjA, α, β; atol, rtol)
353- end
354- end
355- end
356-
357- @timedtestset " tensoradd!" begin
358- A = randn (T, V[1 ] ⊗ V[2 ] ← V[4 ] ⊗ V[5 ])
359- α = randn (T)
360- β = randn (T)
361-
362- # repeat a couple times to get some distribution of arrows
363- for _ in 1 : 5
364- p = randindextuple (numind (A))
365-
366- C1 = randn! (TensorOperations. tensoralloc_add (T, A, p, false , Val (false )))
367- test_rrule (tensoradd!, C1, A, p, false , α, β; atol, rtol)
368-
369- C2 = randn! (TensorOperations. tensoralloc_add (T, A, p, true , Val (false )))
370- test_rrule (tensoradd!, C2, A, p, true , α, β; atol, rtol)
371-
372- A = rand (Bool) ? C1 : C2
373- end
374- end
375-
376- @timedtestset " tensorcontract!" begin
377- for _ in 1 : 5
378- d = 0
379- local V1, V2, V3
380- # retry a couple times to make sure there are at least some nonzero elements
381- for _ in 1 : 10
382- k1 = rand (0 : 3 )
383- k2 = rand (0 : 2 )
384- k3 = rand (0 : 2 )
385- V1 = prod (v -> rand (Bool) ? v' : v, rand (V, k1); init = one (V[1 ]))
386- V2 = prod (v -> rand (Bool) ? v' : v, rand (V, k2); init = one (V[1 ]))
387- V3 = prod (v -> rand (Bool) ? v' : v, rand (V, k3); init = one (V[1 ]))
388- d = min (dim (V1 ← V2), dim (V1' ← V2), dim (V2 ← V3), dim (V2' ← V3))
389- d > 0 && break
390- end
391- ipA = randindextuple (length (V1) + length (V2))
392- pA = _repartition (invperm (linearize (ipA)), length (V1))
393- ipB = randindextuple (length (V2) + length (V3))
394- pB = _repartition (invperm (linearize (ipB)), length (V2))
395- pAB = randindextuple (length (V1) + length (V3))
396-
397- α = randn (T)
398- β = randn (T)
399- V2_conj = prod (conj, V2; init = one (V[1 ]))
400-
401- for conjA in (false , true ), conjB in (false , true )
402- A = randn (T, permute (V1 ← (conjA ? V2_conj : V2), ipA))
403- B = randn (T, permute ((conjB ? V2_conj : V2) ← V3, ipB))
404- C = randn! (
405- TensorOperations. tensoralloc_contract (
406- T, A, pA, conjA, B, pB, conjB, pAB, Val (false )
407- )
408- )
409- test_rrule (
410- tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
411- atol, rtol
412- )
413- end
414- end
415- end
416-
417- @timedtestset " tensorscalar" begin
418- A = randn (T, ProductSpace {typeof(V[1]), 0} ())
419- test_rrule (tensorscalar, A)
420- end
421- end
422173
423174 @timedtestset " Factorizations" begin
424- W = V[1 ] ⊗ V[2 ]
425175 @testset " QR" begin
426176 for T in eltypes,
427177 t in (
@@ -456,16 +206,6 @@ for V in spacelist
456206 fkwargs, atol, rtol, output_tangent = ΔQ
457207 )
458208 test_ad_rrule (last ∘ qr_full, t; fkwargs, atol, rtol, output_tangent = ΔR)
459-
460- # TODO : figure out the following:
461- # N = qr_null(t)
462- # ΔN = Q * rand(T, domain(Q) ← domain(N))
463- # test_ad_rrule(qr_null, t; fkwargs, atol, rtol, output_tangent=ΔN)
464-
465- # if fuse(domain(t)) ≺ fuse(codomain(t))
466- # _, null_pb = Zygote.pullback(qr_null, t)
467- # @test_logs (:warn, r"^`qr") match_mode = :any null_pb(rand_tangent(N))
468- # end
469209 end
470210 end
471211
@@ -504,17 +244,6 @@ for V in spacelist
504244 fkwargs, atol, rtol, output_tangent = ΔL
505245 )
506246 test_ad_rrule (last ∘ lq_full, t; fkwargs, atol, rtol, output_tangent = ΔQ)
507-
508- # TODO : figure out the following
509- # Nᴴ = lq_null(t)
510- # ΔN = rand(T, codomain(Nᴴ) ← codomain(Q)) * Q
511- # test_ad_rrule(lq_null, t; fkwargs, atol, rtol, output_tangent=Nᴴ)
512-
513- # if fuse(codomain(t)) ≺ fuse(domain(t))
514- # _, null_pb = Zygote.pullback(lq_null, t)
515- # # broken due to typo in MAK
516- # # @test_logs (:warn, r"^`lq") match_mode = :any null_pb(rand_tangent(Nᴴ))
517- # end
518247 end
519248 end
520249
@@ -614,17 +343,6 @@ for V in spacelist
614343 @test g1 ≈ g2
615344 end
616345 end
617-
618- # let D = LinearAlgebra.eigvals(C)
619- # ΔD = diag(randn(complex(scalartype(C)), space(C)))
620- # test_rrule(LinearAlgebra.eigvals, C; atol, output_tangent=ΔD,
621- # fkwargs=(; sortby=nothing))
622- # end
623-
624- # let S = LinearAlgebra.svdvals(C)
625- # ΔS = diag(randn(real(scalartype(C)), space(C)))
626- # test_rrule(LinearAlgebra.svdvals, C; atol, output_tangent=ΔS)
627- # end
628346 end
629347 end
630348end
657375 grad4, = Zygote. gradient (g, convert (Array, B₀))
658376 @test convert (Array, grad3) ≈ grad4
659377end
660-
661- # https://github.com/quantumkithub/TensorKit.jl/issues/209
662- @testset " Issue #209" begin
663- function f (T, D)
664- @tensor T[1 , 4 , 1 , 3 ] * D[3 , 4 ]
665- end
666- V = Z2Space (2 , 2 )
667- D = DiagonalTensorMap (randn (4 ), V)
668- T = randn (V ⊗ V ← V ⊗ V)
669- g1, = Zygote. gradient (f, T, D)
670- g2, = Zygote. gradient (f, T, TensorMap (D))
671- @test g1 ≈ g2
672- end
0 commit comments