Skip to content

Commit 77a9008

Browse files
lkdvoskshyatt
andauthored
Add GPU support for Strided backends (#264)
* simplify strided implementations * add JLArrays extension * add CUDA extension * add testing infrastructure * Update Project.toml Strided tag has hit * incremental * Update test file --------- Co-authored-by: Katharine Hyatt <kshyatt@users.noreply.github.com> Co-authored-by: Katharine Hyatt <khyatt@flatironinstitute.org>
1 parent cb2039f commit 77a9008

8 files changed

Lines changed: 218 additions & 96 deletions

File tree

Project.toml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,32 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1919
[weakdeps]
2020
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
2121
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
22-
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
22+
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"
2323
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
24+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
25+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2426
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2527

2628
[extensions]
2729
TensorOperationsBumperExt = "Bumper"
2830
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
2931
TensorOperationsMooncakeExt = "Mooncake"
32+
TensorOperationsCUDACoreExt = "CUDACore"
3033
TensorOperationsEnzymeExt = "Enzyme"
3134
TensorOperationscuTENSORExt = "cuTENSOR"
35+
TensorOperationsJLArraysExt = "JLArrays"
3236

3337
[compat]
3438
Aqua = "0.6, 0.7, 0.8"
39+
Adapt = "4"
3540
Bumper = "0.6, 0.7"
41+
CUDACore = "6"
3642
ChainRulesCore = "1"
3743
ChainRulesTestUtils = "1"
3844
DynamicPolynomials = "0.5, 0.6"
3945
Enzyme = "0.13.115"
4046
EnzymeTestUtils = "0.2"
47+
JLArrays = "0.3"
4148
LRUCache = "1"
4249
LinearAlgebra = "1.6"
4350
Logging = "1.6"
@@ -47,7 +54,7 @@ PrecompileTools = "1.1"
4754
Preferences = "1.4"
4855
PtrArrays = "1.2"
4956
Random = "1"
50-
Strided = "2.4"
57+
Strided = "2.5"
5158
StridedViews = "0.5"
5259
Test = "1"
5360
TupleTools = "1.6"
@@ -57,8 +64,10 @@ cuTENSOR = "6"
5764
julia = "1.10"
5865

5966
[extras]
67+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
6068
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6169
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
70+
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"
6271
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6372
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
6473
cuRAND = "20fd9a0b-12d5-4c2f-a8af-7c34e9e60431"
@@ -72,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7281
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
7382

7483
[targets]
75-
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "cuRAND", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
84+
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "cuRAND", "CUDACore", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils", "Adapt", "JLArrays"]

ext/TensorOperationsCUDACoreExt.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
module TensorOperationsCUDACoreExt
2+
3+
using CUDACore
4+
using TensorOperations
5+
using TensorOperations: TensorOperations as TO
6+
7+
#-------------------------------------------------------------------------------------------
8+
# Allocator
9+
#-------------------------------------------------------------------------------------------
10+
11+
TO.tensoradd_type(TC, A::CuArray, pA::Index2Tuple, conjA::Bool) =
12+
CuArray{TC, TO.numind(pA)}
13+
14+
function TO.CUDAAllocator()
15+
Mout = CUDACore.UnifiedMemory
16+
Min = CUDACore.default_memory
17+
Mtemp = CUDACore.default_memory
18+
return TO.CUDAAllocator{Mout, Min, Mtemp}()
19+
end
20+
21+
function TO.tensoralloc_add(
22+
TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
23+
istemp::Val, allocator::TO.CUDAAllocator
24+
)
25+
ttype = CuArray{TC, TO.numind(pA)}
26+
structure = TO.tensoradd_structure(A, pA, conjA)
27+
return TO.tensoralloc(ttype, structure, istemp, allocator)::ttype
28+
end
29+
30+
function TO.tensoralloc_contract(
31+
TC,
32+
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
33+
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
34+
pAB::Index2Tuple,
35+
istemp::Val, allocator::TO.CUDAAllocator
36+
)
37+
ttype = CuArray{TC, TO.numind(pAB)}
38+
structure = TO.tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
39+
return TO.tensoralloc(ttype, structure, istemp, allocator)::ttype
40+
end
41+
42+
# NOTE: the general implementation in the `DefaultAllocator` case works just fine, without
43+
# selecting an explicit memory model
44+
function TO.tensoralloc(
45+
::Type{CuArray{T, N}}, structure,
46+
::Val{istemp}, allocator::TO.CUDAAllocator{Mout, Min, Mtemp}
47+
) where {T, N, istemp, Mout, Min, Mtemp}
48+
M = istemp ? Mtemp : Mout
49+
return CuArray{T, N, M}(undef, structure)
50+
end
51+
52+
function TO.tensorfree!(C::CuArray, ::TO.CUDAAllocator)
53+
CUDACore.unsafe_free!(C)
54+
return nothing
55+
end
56+
57+
end

ext/TensorOperationsJLArraysExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TensorOperationsJLArraysExt
2+
3+
using JLArrays
4+
using TensorOperations
5+
6+
TensorOperations.tensoradd_type(TC, A::JLArray, pA::Index2Tuple, conjA::Bool) =
7+
JLArray{TC, sum(length.(pA))}
8+
9+
end

ext/TensorOperationscuTENSORExt.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -144,57 +144,6 @@ function _custrided(
144144
end
145145
end
146146

147-
#-------------------------------------------------------------------------------------------
148-
# Allocator
149-
#-------------------------------------------------------------------------------------------
150-
function TO.CUDAAllocator()
151-
Mout = CUDACore.UnifiedMemory
152-
Min = CUDACore.default_memory
153-
Mtemp = CUDACore.default_memory
154-
return CUDAAllocator{Mout, Min, Mtemp}()
155-
end
156-
157-
function TO.tensoralloc_add(
158-
TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
159-
istemp::Val, allocator::CUDAAllocator
160-
)
161-
ttype = CuArray{TC, TO.numind(pA)}
162-
structure = TO.tensoradd_structure(A, pA, conjA)
163-
return TO.tensoralloc(ttype, structure, istemp, allocator)::ttype
164-
end
165-
166-
function TO.tensoralloc_contract(
167-
TC,
168-
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
169-
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
170-
pAB::Index2Tuple,
171-
istemp::Val, allocator::CUDAAllocator
172-
)
173-
ttype = CuArray{TC, TO.numind(pAB)}
174-
structure = TO.tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
175-
return tensoralloc(ttype, structure, istemp, allocator)::ttype
176-
end
177-
178-
# Overwrite tensoradd_type
179-
function TO.tensoradd_type(TC, A::CuArray, pA::Index2Tuple, conjA::Bool)
180-
return CuArray{TC, sum(length.(pA))}
181-
end
182-
183-
# NOTE: the general implementation in the `DefaultAllocator` case works just fine, without
184-
# selecting an explicit memory model
185-
function TO.tensoralloc(
186-
::Type{CuArray{T, N}}, structure,
187-
::Val{istemp}, allocator::CUDAAllocator{Mout, Min, Mtemp}
188-
) where {T, N, istemp, Mout, Min, Mtemp}
189-
M = istemp ? Mtemp : Mout
190-
return CuArray{T, N, M}(undef, structure)
191-
end
192-
193-
function TO.tensorfree!(C::CuArray, ::CUDAAllocator)
194-
CUDACore.unsafe_free!(C)
195-
return nothing
196-
end
197-
198147
#-------------------------------------------------------------------------------------------
199148
# Implementation
200149
#-------------------------------------------------------------------------------------------

src/implementation/allocator.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ end
166166
function tensoradd_type(TC, A::Base.PermutedDimsArray, pA::Index2Tuple, conjA::Bool)
167167
return tensoradd_type(TC, A.parent, pA, conjA)
168168
end
169+
function tensoradd_type(TC, A::StridedView, pA::Index2Tuple, conjA::Bool)
170+
return tensoradd_type(TC, parent(A), pA, conjA)
171+
end
169172

170173
function tensoradd_structure(A::AbstractArray, pA::Index2Tuple, conjA::Bool)
171174
return size.(Ref(A), linearize(pA))

src/implementation/strided.jl

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,12 @@
11
const StridedViewOrDiagonal = Union{StridedView, Diagonal}
22

3-
_ishostarray(x::StridedView) = (pointer(x) isa Ptr)
4-
_ishostarray(x::Diagonal) = (pointer(x.diag) isa Ptr)
3+
select_backend(::typeof(tensoradd!), C::StridedView, A::StridedView) = StridedNative()
4+
select_backend(::typeof(tensortrace!), C::StridedView, A::StridedView) = StridedNative()
55

6-
function select_backend(::typeof(tensoradd!), C::StridedView, A::StridedView)
7-
if _ishostarray(C) && _ishostarray(A)
8-
return StridedNative()
9-
else
10-
return NoBackend()
11-
end
12-
end
13-
function select_backend(::typeof(tensortrace!), C::StridedView, A::StridedView)
14-
if _ishostarray(C) && _ishostarray(A)
15-
return StridedNative()
16-
else
17-
return NoBackend()
18-
end
19-
end
20-
21-
function select_backend(
22-
::typeof(tensorcontract!), C::StridedView, A::StridedView, B::StridedView
23-
)
24-
if _ishostarray(C) && _ishostarray(A) && _ishostarray(B)
25-
return eltype(C) <: LinearAlgebra.BlasFloat ? StridedBLAS() : StridedNative()
26-
else
27-
return NoBackend()
28-
end
29-
end
30-
function select_backend(
31-
::typeof(tensorcontract!), C::StridedViewOrDiagonal,
32-
A::StridedViewOrDiagonal, B::StridedViewOrDiagonal
33-
)
34-
if _ishostarray(C) && _ishostarray(A) && _ishostarray(B)
35-
return StridedNative()
36-
else
37-
return NoBackend()
38-
end
39-
end
6+
select_backend(::typeof(tensorcontract!), C::StridedView, A::StridedView, B::StridedView) =
7+
eltype(C) <: LinearAlgebra.BlasFloat ? StridedBLAS() : StridedNative()
8+
select_backend(::typeof(tensorcontract!), C::StridedViewOrDiagonal, A::StridedViewOrDiagonal, B::StridedViewOrDiagonal) =
9+
StridedNative()
4010

4111
#-------------------------------------------------------------------------------------------
4212
# Force strided implementation on AbstractArray instances with Strided backend

test/gpu.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
using TensorOperations
2+
using TensorOperations: StridedBLAS, StridedNative, linearize, numout
3+
using Test
4+
using Adapt
5+
using TupleTools
6+
using JLArrays
7+
using VectorInterface
8+
using CUDACore
9+
10+
test_result(a::AbstractArray, b::AbstractArray; kwargs...) =
11+
isapprox(collect(a), collect(b); kwargs...)
12+
13+
function compare(f, AT::Type, xs...; kwargs...)
14+
cpu_in = map(deepcopy, xs) # copy on CPU
15+
gpu_in = map(adapt(AT), xs) # adapt on GPU
16+
17+
cpu_out = f(cpu_in...)
18+
gpu_out = f(gpu_in...)
19+
20+
return test_result(cpu_out, gpu_out; kwargs...)
21+
end
22+
23+
# types to test for
24+
ATs = []
25+
!is_buildkite && push!(ATs, JLArray)
26+
CUDACore.functional() && push!(ATs, CuArray)
27+
28+
backends = [StridedBLAS(), StridedNative()]
29+
30+
@testset "tensoradd! ($AT)" for AT in ATs
31+
sz = (3, 5, 4, 6)
32+
p = (3, 1, 4, 2)
33+
for backend in backends, T in (Float32, ComplexF32)
34+
A = randn(T, sz)
35+
C = randn(T, TupleTools.getindices(sz, p))
36+
37+
@test compare(AT, C, A) do c, a
38+
tensoradd!(c, a, (p, ()), false, One(), Zero(), backend)
39+
end
40+
41+
α = rand(T)
42+
@test compare(AT, C, A) do c, a
43+
tensoradd!(c, a, (p, ()), false, α, Zero(), backend)
44+
end
45+
46+
β = rand(T)
47+
@test compare(AT, C, A) do c, a
48+
tensoradd!(c, a, (p, ()), false, α, β, backend)
49+
end
50+
51+
T <: Real || @test compare(AT, C, A) do c, a
52+
tensoradd!(c, a, (p, ()), true, α, β, backend)
53+
end
54+
end
55+
end
56+
57+
@testset "tensortrace! ($AT)" for AT in ATs
58+
sz = (2, 4, 3, 2)
59+
p = (2, 3)
60+
q = ((1,), (4,))
61+
62+
for backend in backends, T in (Float32, ComplexF32)
63+
A = randn(T, sz)
64+
C = randn(T, TupleTools.getindices(sz, p))
65+
66+
@test compare(AT, C, A) do c, a
67+
tensortrace!(c, a, (p, ()), q, false, One(), Zero(), backend)
68+
end
69+
70+
α = rand(T)
71+
@test compare(AT, C, A) do c, a
72+
tensortrace!(c, a, (p, ()), q, false, α, Zero(), backend)
73+
end
74+
75+
β = rand(T)
76+
@test compare(AT, C, A) do c, a
77+
tensortrace!(c, a, (p, ()), q, false, α, β, backend)
78+
end
79+
80+
T <: Real || @test compare(AT, C, A) do c, a
81+
tensortrace!(c, a, (p, ()), q, true, α, β, backend)
82+
end
83+
end
84+
end
85+
86+
@testset "tensorcontract! ($AT)" for AT in ATs
87+
sz = (2, 4, 3, 4, 2, 5)
88+
pA = ((4, 1), (2, 3))
89+
pB = ((3, 1), (2,))
90+
pAB = ((1, 2, 3), ())
91+
92+
for backend in backends, T in (Float32, ComplexF32)
93+
A = randn(T, (2, 4, 3, 2))
94+
B = randn(T, (3, 3, 4))
95+
C = randn(T, (2, 2, 3))
96+
97+
@test compare(AT, C, A, B) do c, a, b
98+
tensorcontract!(c, a, pA, false, b, pB, false, pAB, One(), Zero(), backend)
99+
end
100+
101+
α = rand(T)
102+
@test compare(AT, C, A, B) do c, a, b
103+
tensorcontract!(c, a, pA, false, b, pB, false, pAB, α, Zero(), backend)
104+
end
105+
106+
β = rand(T)
107+
@test compare(AT, C, A, B) do c, a, b
108+
tensorcontract!(c, a, pA, false, b, pB, false, pAB, α, β, backend)
109+
end
110+
111+
if !(T <: Real)
112+
@test compare(AT, C, A, B) do c, a, b
113+
tensorcontract!(c, a, pA, true, b, pB, false, pAB, α, β, backend)
114+
end
115+
@test compare(AT, C, A, B) do c, a, b
116+
tensorcontract!(c, a, pA, false, b, pB, true, pAB, α, β, backend)
117+
end
118+
@test compare(AT, C, A, B) do c, a, b
119+
tensorcontract!(c, a, pA, true, b, pB, true, pAB, α, β, backend)
120+
end
121+
end
122+
end
123+
124+
end

test/runtests.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ if !is_buildkite
5050
end
5151
end
5252

53-
if is_buildkite
54-
# note: cuTENSOR should not be loaded before this point
55-
# as there is a test which requires it to be loaded after
56-
@testset "cuTENSOR extension" verbose = true begin
57-
include("cutensor.jl")
58-
end
53+
# note: cuTENSOR should not be loaded before this point
54+
# as there is a test which requires it to be loaded after
55+
@testset "cuTENSOR extension" verbose = true begin
56+
include("cutensor.jl")
57+
end
58+
@testset "GPUArrays" verbose = true begin
59+
include("gpu.jl")
5960
end
6061

6162
if !is_buildkite

0 commit comments

Comments
 (0)