Skip to content

Commit 5726be9

Browse files
committed
Use approx for some GPU permutes
1 parent cdac60b commit 5726be9

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

test/cuda/tensors.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -304,14 +304,12 @@ for V in spacelist
304304
end
305305
end
306306

307-
CUDA.@allowscalar begin
308-
t3 = @constinferred repartition(t, $k)
309-
t3 = repartition(t, k)
310-
@test norm(t3) ≈ norm(t)
311-
t3′ = @constinferred repartition!(similar(t3), t′)
312-
@test norm(t3′) ≈ norm(t′)
313-
@test dot(t′, t) ≈ dot(t3′, t3)
314-
end
307+
t3 = @constinferred repartition(t, $k)
308+
t3 = repartition(t, k)
309+
@test norm(t3) ≈ norm(t)
310+
t3′ = @constinferred repartition!(similar(t3), t′)
311+
@test norm(t3′) ≈ norm(t′)
312+
@test dot(t′, t) ≈ dot(t3′, t3)
315313
end
316314
end
317315
if BraidingStyle(I) isa SymmetricBraiding
@@ -322,9 +320,10 @@ for V in spacelist
322320
for p in permutations(1:5)
323321
p1 = ntuple(n -> p[n], k)
324322
p2 = ntuple(n -> p[k + n], 5 - k)
325-
dt2 = CUDA.@allowscalar permute(t, (p1, p2))
326323
ht2 = permute(TensorKit.to_cpu(t), (p1, p2))
327-
@test ht2 == TensorKit.to_cpu(dt2)
324+
dt2 = permute(t, (p1, p2))
325+
ht2′ = TensorKit.to_cpu(dt2)
326+
@test ht2 ≈ ht2′
328327
end
329328

330329
dt3 = CUDA.@allowscalar repartition(t, k)
@@ -380,7 +379,7 @@ for V in spacelist
380379
end
381380
@test ta ≈ tb
382381
end
383-
#=if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
382+
if BraidingStyle(I) isa Bosonic && hasfusiontensor(I)
384383
@timedtestset "Tensor contraction: test via CPU" begin
385384
dA1 = CUDA.randn(ComplexF64, V1' * V2', V3')
386385
dA2 = CUDA.randn(ComplexF64, V3 * V4, V5)
@@ -395,7 +394,7 @@ for V in spacelist
395394
TensorKit.to_cpu(dH)[s1, s2, t1, t2]
396395
@test TensorKit.to_cpu(dHrA12) ≈ hHrA12
397396
end
398-
end=# # doesn't yet work because of AdjointTensor
397+
end
399398
@timedtestset "Index flipping: test flipping inverse" begin
400399
t = CUDA.rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1)
401400
for i in 1:4

0 commit comments

Comments
 (0)