Skip to content

Commit a4bff32

Browse files
committed
Updates for cuTENSOR
1 parent d459349 commit a4bff32

2 files changed

Lines changed: 17 additions & 20 deletions

File tree

Project.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,24 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1818
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919

2020
[weakdeps]
21-
#AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
21+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
2222
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2323
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2424
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2525
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2626

2727
[extensions]
28-
#TensorKitAMDGPUExt = "AMDGPU"
28+
TensorKitAMDGPUExt = "AMDGPU"
2929
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3030
TensorKitChainRulesCoreExt = "ChainRulesCore"
3131
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3232

3333
[compat]
34-
#AMDGPU = "2"
34+
AMDGPU = "2"
3535
Adapt = "4"
3636
Aqua = "0.6, 0.7, 0.8"
3737
CUDA = "5"
38-
cuTENSOR = "2.2"
38+
cuTENSOR = "2"
3939
ChainRulesCore = "1"
4040
ChainRulesTestUtils = "1"
4141
Combinatorics = "1"
@@ -79,3 +79,4 @@ test = ["Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "LinearAlgebra", "
7979
CUDA = {url = "https://github.com/JuliaGPU/CUDA.jl", rev = "master"}
8080
cuTENSOR = {url = "https://github.com/JuliaGPU/CUDA.jl", subdir="lib/cutensor", rev = "ksh/cutensor_bump"}
8181
MatrixAlgebraKit = {url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl", rev = "ksh/tk"}
82+
TensorOperations = {url = "https://github.com/QuantumKitHub/TensorOperations.jl", rev = "ksh/cutensor_bump"}

test/cuda.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ for V in spacelist
9797
@test norm(t + t, p) ≈ 2 * norm(t, p)
9898
@test norm(t) ≈ norm(t')
9999

100-
t2 = @constinferred rand!(similar(t))
100+
t2 = @constinferred CUDA.rand!(similar(t))
101101
β = rand(T)
102102
@test @constinferred(dot(β * t2, α * t)) ≈ conj(β) * α * conj(dot(t, t2))
103103
@test dot(t2, t) ≈ conj(dot(t, t2))
@@ -123,7 +123,7 @@ for V in spacelist
123123
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
124124
for T in (Float32, ComplexF64)
125125
t = CUDA.rand(T, W)
126-
t2 = @constinferred rand!(similar(t))
126+
t2 = @constinferred CUDA.rand!(similar(t))
127127
@test norm(t, 2) ≈ norm(ad(t), 2)
128128
@test dot(t2, t) ≈ dot(ad(t2), ad(t))
129129
α = rand(T)
@@ -160,27 +160,23 @@ for V in spacelist
160160
@test LinearAlgebra.isdiag(D)
161161
@test LinearAlgebra.diag(D) == d
162162
end=#
163-
#=
164163
@timedtestset "Permutations: test via inner product invariance" begin
165164
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
166165
t = CUDA.rand(ComplexF64, W)
167-
t′ = randn!(similar(t))
166+
t′ = CUDA.randn!(similar(t))
168167
for k in 0:5
169168
for p in permutations(1:5)
170169
p1 = ntuple(n -> p[n], k)
171170
p2 = ntuple(n -> p[k + n], 5 - k)
172-
# TODO fix me
173-
# t2 = @constinferred permute(t, (p1, p2))
171+
t2 = @constinferred permute(t, (p1, p2))
174172
t2 = permute(t, (p1, p2))
175173
@test norm(t2) ≈ norm(t)
176174
t2′ = permute(t′, (p1, p2))
177175
@test dot(t2′, t2) ≈ dot(t′, t) ≈ dot(transpose(t2′), transpose(t2))
178176
end
179177

180-
# TODO fix me
181-
#t3 = VERSION < v"1.7" ? repartition(t, k) :
182-
# @constinferred repartition(t, $k)
183-
t3 = repartition(t, k)
178+
t3 = @constinferred repartition(t, $k)
179+
t3 = repartition(t, k)
184180
@test norm(t3) ≈ norm(t)
185181
t3′ = @constinferred repartition!(similar(t3), t′)
186182
@test norm(t3′) ≈ norm(t′)
@@ -275,7 +271,7 @@ for V in spacelist
275271
for T in (Float64, ComplexF64)
276272
t1 = CUDA.rand(T, W1, W1)
277273
t2 = CUDA.rand(T, W2, W2)
278-
t = CUDA.rand(T, W1, W2)
274+
t = CUDA.rand(T, W1, W2)
279275
@test t1 * (t1 \ t) ≈ t
280276
@test (t / t2) * t2 ≈ t
281277
@test t1 \ one(t1) ≈ inv(t1)
@@ -295,12 +291,12 @@ for V in spacelist
295291
for T in (Float32, Float64, ComplexF32, ComplexF64)
296292
t1 = CUDA.rand(T, W1, W1)
297293
t2 = CUDA.rand(T, W2, W2)
298-
t = CUDA.rand(T, W1, W2)
294+
t = CUDA.rand(T, W1, W2)
299295
d1 = dim(W1)
300296
d2 = dim(W2)
301297
At1 = reshape(convert(Array, t1), d1, d1)
302298
At2 = reshape(convert(Array, t2), d2, d2)
303-
At = reshape(convert(Array, t), d1, d2)
299+
At = reshape(convert(Array, t), d1, d2)
304300
@test ad(t1 * t) ≈ ad(t1) * ad(t)
305301
@test ad(t1' * t) ≈ ad(t1)' * ad(t)
306302
@test ad(t2 * t') ≈ ad(t2) * ad(t)'
@@ -324,7 +320,7 @@ for V in spacelist
324320
@test ad(t1' / t') ≈ ad(t1)' / ad(t)'
325321
end
326322
end
327-
end=#
323+
end
328324
@timedtestset "Factorization" begin
329325
W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5
330326
for T in (Float32, ComplexF64)
@@ -539,7 +535,7 @@ for V in spacelist
539535
# end
540536
#
541537
# TODO
542-
#=@timedtestset "Tensor product: test via norm preservation" begin
538+
@timedtestset "Tensor product: test via norm preservation" begin
543539
for T in (Float32, ComplexF64)
544540
t1 = CUDA.rand(T, V2 ⊗ V3 ⊗ V1, V1 ⊗ V2)
545541
t2 = CUDA.rand(T, V2 ⊗ V1 ⊗ V3, V1 ⊗ V1)
@@ -570,7 +566,7 @@ for V in spacelist
570566
@tensor t′[1, 2, 3, 4, 5, 6] := t1[1, 2, 3] * t2[4, 5, 6]
571567
@test t ≈ t′
572568
end
573-
end=#
569+
end
574570
end
575571
end
576572

0 commit comments

Comments
 (0)