Skip to content

Commit 000786e

Browse files
committed
reduce tensor product and deligne product tensor sizes
1 parent e755a81 commit 000786e

2 files changed

Lines changed: 19 additions & 29 deletions

File tree

test/cuda/tensors.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,8 @@ for V in spacelist
533533
@timedtestset "Tensor product: test via norm preservation" begin
534534
for T in (ComplexF64,) # Float32 case broken because of cuTENSOR
535535
@time "Construction" begin
536-
if UnitStyle(I) isa SimpleUnit || !isempty(blocksectors(V2 V1))
537-
t1 = CUDA.rand(T, V2 V3 V1, V1)
538-
t2 = CUDA.rand(T, V2 V1 V3, V1)
539-
else
540-
t1 = CUDA.rand(T, V3 V4 V5, V1')
541-
t2 = CUDA.rand(T, (V3 V4 V5)', V1)
542-
end
536+
t1 = CUDA.rand(T, V1, V5')
537+
t2 = CUDA.rand(T, V2 V3, V4')
543538
end
544539
@time "Product" begin
545540
t = @constinferred (t1 t2)
@@ -551,8 +546,8 @@ for V in spacelist
551546
end
552547
symmetricbraiding && @timedtestset "Tensor product: test via conversion" begin
553548
for T in (Float32, ComplexF64)
554-
t1 = CUDA.rand(T, V2 V3 V1, V1)
555-
t2 = CUDA.rand(T, V2 V1 V3, V2)
549+
t1 = CUDA.rand(T, V1, V5')
550+
t2 = CUDA.rand(T, V2 V3, V4')
556551
d1 = dim(codomain(t1))
557552
d2 = dim(codomain(t2))
558553
d3 = dim(domain(t1))
@@ -564,11 +559,11 @@ for V in spacelist
564559
end
565560
symmetricbraiding && @timedtestset "Tensor product: test via tensor contraction" begin
566561
for T in (Float32, ComplexF64)
567-
t1 = CUDA.rand(T, V2 V3 V1)
568-
t2 = CUDA.rand(T, V2 V1 V3)
562+
t1 = CUDA.rand(T, V1, V5')
563+
t2 = CUDA.rand(T, V2 V3, V4')
569564
t = @constinferred (t1 t2)
570-
@tensor t′[1, 2, 3, 4, 5, 6] := t1[1, 2, 3] * t2[4, 5, 6]
571-
# @test t ≈ t′ # TODO broken for symmetry: Irrep[ℤ₃]
565+
@tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5]
566+
@test t t′ # This should really not be broken
572567
end
573568
end
574569
end
@@ -581,8 +576,8 @@ end
581576
V1, V2, V3, V4, V5 = Vlist1
582577
W1, W2, W3, W4, W5 = Vlist2
583578
for T in (Float32, ComplexF64)
584-
t1 = rand(T, V1 V2, V3' V4)
585-
t2 = rand(T, W2, W1 W1')
579+
t1 = rand(T, V2 V3, (V4 V5)')
580+
t2 = rand(T, W2, (W3 W4)')
586581
t = @constinferred (t1 t2)
587582
d1 = dim(codomain(t1))
588583
d2 = dim(codomain(t2))

test/tensors/contractions.jl

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,17 @@ for V in spacelist
9393
end
9494
@timedtestset "Tensor product: test via norm preservation" begin
9595
for T in (Float32, ComplexF64)
96-
if UnitStyle(I) isa SimpleUnit || !isempty(blocksectors(V2 V1))
97-
t1 = rand(T, V2 V3 V1, V1 V2)
98-
t2 = rand(T, V2 V1 V3, V1 V1)
99-
else
100-
t1 = rand(T, V3 V4 V5, (V1 V2)')
101-
t2 = rand(T, (V3 V4 V5)', V1 V2)
102-
end
96+
t1 = rand(T, V1, V5')
97+
t2 = rand(T, V2 V3, V4')
10398
t = @constinferred (t1 t2)
10499
@test norm(t) norm(t1) * norm(t2)
105100
end
106101
end
107102
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
108103
@timedtestset "Tensor product: test via conversion" begin
109104
for T in (Float32, ComplexF64)
110-
t1 = rand(T, V2 V3 V1, V1)
111-
t2 = rand(T, V2 V1 V3, V2)
105+
t1 = rand(T, V1, V5')
106+
t2 = rand(T, V2 V3, V4')
112107
t = @constinferred (t1 t2)
113108
d1 = dim(codomain(t1))
114109
d2 = dim(codomain(t2))
@@ -123,10 +118,10 @@ for V in spacelist
123118
end
124119
symmetricbraiding && @timedtestset "Tensor product: test via tensor contraction" begin
125120
for T in (Float32, ComplexF64)
126-
t1 = rand(T, V2 V3 V1)
127-
t2 = rand(T, V2 V1 V3)
121+
t1 = rand(T, V1, V5')
122+
t2 = rand(T, V2 V3, V4')
128123
t = @constinferred (t1 t2)
129-
@tensor t′[1, 2, 3, 4, 5, 6] := t1[1, 2, 3] * t2[4, 5, 6]
124+
@tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5]
130125
@test t t′
131126
end
132127
end
@@ -161,8 +156,8 @@ end
161156
V1, V2, V3, V4, V5 = Vlist1
162157
W1, W2, W3, W4, W5 = Vlist2
163158
for T in (Float32, ComplexF64)
164-
t1 = rand(T, V1 V2, V3' V4)
165-
t2 = rand(T, W2, W1 W1')
159+
t1 = rand(T, V2 V3, (V4 V5)')
160+
t2 = rand(T, W2, (W3 W4)')
166161
t = @constinferred (t1 t2)
167162
d1 = dim(codomain(t1))
168163
d2 = dim(codomain(t2))

0 commit comments

Comments
 (0)