Skip to content

Commit a3f4bfc

Browse files
committed
Test adjustment
1 parent b57ff80 commit a3f4bfc

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

test/cuda/tensors.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ for V in spacelist
1414
println("---------------------------------------")
1515
println("CUDA Tensors with symmetry: $Istr")
1616
println("---------------------------------------")
17+
hasbraiding = BraidingStyle(I) isa HasBraiding
1718
symmetricbraiding = BraidingStyle(I) isa SymmetricBraiding
1819
@timedtestset "Tensors with symmetry: $Istr" verbose = true begin
1920
V1, V2, V3, V4, V5 = V
@@ -359,14 +360,14 @@ for V in spacelist
359360
@test TensorKit.to_cpu(dHrA12) hHrA12
360361
end
361362
end
362-
BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin
363+
hasbraiding && @timedtestset "Index flipping: test flipping inverse" begin
363364
t = CUDA.rand(ComplexF64, V1 V1' V1' V1)
364365
for i in 1:4
365366
@test t flip(flip(t, i), i; inv = true)
366367
@test t flip(flip(t, i; inv = true), i)
367368
end
368369
end
369-
@timedtestset "Index flipping: test via explicit flip" begin
370+
symmetricbraiding && @timedtestset "Index flipping: test via explicit flip" begin
370371
t = CUDA.rand(ComplexF64, V1 V1' V1' V1)
371372
F1 = adapt(CuArray{ComplexF64}, unitary(flip(V1), V1))
372373

@@ -379,7 +380,7 @@ for V in spacelist
379380
@tensor tf[a, b; c, d] := conj(F1[d, d']) * t[a, b; c, d']
380381
@test twist!(flip(t, 4), 4) tf
381382
end
382-
@timedtestset "Index flipping: test via contraction" begin
383+
symmetricbraiding && @timedtestset "Index flipping: test via contraction" begin
383384
t1 = CUDA.rand(ComplexF64, V1 V2 V3 V4)
384385
t2 = CUDA.rand(ComplexF64, V2' V5 V4' V1)
385386
@tensor ta[a, b] := t1[x, y, a, z] * t2[y, b, z, x]
@@ -565,10 +566,10 @@ end
565566
d2 = dim(codomain(t2))
566567
d3 = dim(domain(t1))
567568
d4 = dim(domain(t2))
568-
At = convert(Array, t)
569+
At = convert(Array, adapt(Vector{T}, t))
569570
@test reshape(At, (d1, d2, d3, d4))
570-
reshape(convert(Array, t1), (d1, 1, d3, 1)) .*
571-
reshape(convert(Array, t2), (1, d2, 1, d4))
571+
reshape(convert(Array, adapt(Vector{T}, t1)), (d1, 1, d3, 1)) .*
572+
reshape(convert(Array, adapt(Vector{T}, t2)), (1, d2, 1, d4))
572573
end
573574
end
574575
end

0 commit comments

Comments
 (0)