Skip to content

Commit de03dc2

Browse files
lkdvoskshyatt
authored andcommitted
Add AMDGPU allocator support
1 parent d8458fb commit de03dc2

3 files changed

Lines changed: 58 additions & 1 deletion

File tree

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1717
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1818

1919
[weakdeps]
20+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
2021
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
2122
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2223
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"
@@ -26,6 +27,7 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2627
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2728

2829
[extensions]
30+
TensorOperationsAMDGPUExt = "AMDGPU"
2931
TensorOperationsBumperExt = "Bumper"
3032
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
3133
TensorOperationsMooncakeExt = "Mooncake"
@@ -66,7 +68,6 @@ julia = "1.10"
6668

6769
[extras]
6870
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
69-
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
7071
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7172
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
7273
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"

ext/TensorOperationsAMDGPUExt.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
module TensorOperationsAMDGPUExt
2+
3+
using AMDGPU
4+
using TensorOperations
5+
using TensorOperations: TensorOperations as TO
6+
7+
#-------------------------------------------------------------------------------------------
8+
# Allocator
9+
#-------------------------------------------------------------------------------------------
10+
11+
# Result type of a tensor addition whose input `A` is backed by a `ROCArray`: a `ROCArray`
# with element type `TC` and `TO.numind(pA)` dimensions. `conjA` does not affect the type.
# `AnyROCArray` (AMDGPU.jl's alias covering `ROCArray` and its wrapped variants) is
# qualified explicitly so this method does not rely on which names AMDGPU exports.
TO.tensoradd_type(TC, A::AMDGPU.AnyROCArray, pA::Index2Tuple, conjA::Bool) =
    ROCArray{TC, TO.numind(pA)}
13+
14+
# Allocate the destination of a tensor addition when an `AMDAllocator` is requested.
# The destination is always a `ROCArray{TC}` with `TO.numind(pA)` dimensions, whatever the
# array type of `A`; its structure is obtained from `TO.tensoradd_structure`.
function TO.tensoralloc_add(
        TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
        istemp::Val, allocator::TO.AMDAllocator
    )
    ndims_out = TO.numind(pA)
    OutType = ROCArray{TC, ndims_out}
    dims = TO.tensoradd_structure(A, pA, conjA)
    out = TO.tensoralloc(OutType, dims, istemp, allocator)
    # Assert the result type so callers get an inferable return value.
    return out::OutType
end
22+
23+
# Allocate the destination of a tensor contraction when an `AMDAllocator` is requested.
# The destination is always a `ROCArray{TC}` with `TO.numind(pAB)` dimensions, whatever the
# array types of `A` and `B`; its structure comes from `TO.tensorcontract_structure`.
function TO.tensoralloc_contract(
        TC,
        A::AbstractArray, pA::Index2Tuple, conjA::Bool,
        B::AbstractArray, pB::Index2Tuple, conjB::Bool,
        pAB::Index2Tuple,
        istemp::Val, allocator::TO.AMDAllocator
    )
    ndims_out = TO.numind(pAB)
    OutType = ROCArray{TC, ndims_out}
    dims = TO.tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB)
    out = TO.tensoralloc(OutType, dims, istemp, allocator)
    # Assert the result type so callers get an inferable return value.
    return out::OutType
end
34+
35+
# Allocate an uninitialized `ROCArray{T, N}` whose dimensions are given by `structure`.
# The temporary flag and the allocator instance are used for dispatch only: temporary and
# non-temporary tensors are allocated in exactly the same way.
#
# The flag is matched with a plain `::Val`. Writing `::Val{istemp}` would require `istemp`
# to be declared in the `where` clause; otherwise it is looked up as a global when the
# method is defined and raises an `UndefVarError`.
#
# NOTE: the general implementation in the `DefaultAllocator` case works just fine, without
# selecting an explicit memory model
function TO.tensoralloc(
        ::Type{<:ROCArray{T, N}}, structure,
        ::Val, ::TO.AMDAllocator
    ) where {T, N}
    return ROCArray{T, N}(undef, structure)
end
43+
44+
# Hand `C` to `AMDGPU.unsafe_free!` and return `nothing`; `C` must not be used afterwards.
TO.tensorfree!(C::ROCArray, ::TO.AMDAllocator) = (AMDGPU.unsafe_free!(C); nothing)
48+
49+
end

src/implementation/allocator.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ parameters `Min`, `Mout`, `Mtemp` can be any of the CUDA.jl memory types, i.e.
3030
"""
3131
struct CUDAAllocator{Mout, Min, Mtemp} end
3232

33+
"""
34+
AMDAllocator()
35+
36+
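Allocator that uses the AMDGPU.jl memory manager and therefore allocates `ROCArray` instances. The corresponding methods are provided by a package extension, so AMDGPU.jl must be loaded for this allocator to be usable.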
37+
"""
38+
struct AMDAllocator end # field-less singleton: carries no settings, used as a dispatch tag
39+
3340
"""
3441
ManualAllocator()
3542

0 commit comments

Comments
 (0)