-
Notifications
You must be signed in to change notification settings - Fork 70
Expand file tree
/
Copy pathTensorOperationsCUDACoreExt.jl
More file actions
57 lines (48 loc) · 1.81 KB
/
TensorOperationsCUDACoreExt.jl
File metadata and controls
57 lines (48 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
module TensorOperationsCUDACoreExt
using CUDACore
using TensorOperations
using TensorOperations: TensorOperations as TO
#-------------------------------------------------------------------------------------------
# Allocator
#-------------------------------------------------------------------------------------------
# Output array type for a `tensoradd` whose input `A` is a `CuArray`: a `CuArray` with
# element type `TC` and `TO.numind(pA)` dimensions. `A` and `conjA` only take part in
# dispatch here; their values are not inspected.
function TO.tensoradd_type(TC, A::CuArray, pA::Index2Tuple, conjA::Bool)
    return CuArray{TC, TO.numind(pA)}
end
# Zero-argument constructor that fills in the three memory-kind type parameters of
# `TO.CUDAAllocator`: the first is `CUDACore.UnifiedMemory`, the second and third are both
# `CUDACore.default_memory`.
function TO.CUDAAllocator()
    default_mem = CUDACore.default_memory
    return TO.CUDAAllocator{CUDACore.UnifiedMemory, default_mem, default_mem}()
end
# Allocate the destination of a `tensoradd` through a `TO.CUDAAllocator`.
#
# The destination is always requested as a `CuArray{TC, TO.numind(pA)}`, whatever the
# concrete array type of `A` is. Its structure comes from `TO.tensoradd_structure`, and
# `istemp` and `allocator` are forwarded unchanged to `TO.tensoralloc`.
function TO.tensoralloc_add(
        TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
        istemp::Val, allocator::TO.CUDAAllocator
    )
    CT = CuArray{TC, TO.numind(pA)}
    dest = TO.tensoralloc(CT, TO.tensoradd_structure(A, pA, conjA), istemp, allocator)
    # The assertion pins the return type to the requested `CuArray{TC, N}`.
    return dest::CT
end
# Allocate the destination of a `tensorcontract` through a `TO.CUDAAllocator`.
#
# The destination is always requested as a `CuArray{TC, TO.numind(pAB)}`, whatever the
# concrete array types of `A` and `B` are. Its structure comes from
# `TO.tensorcontract_structure`, and `istemp` and `allocator` are forwarded unchanged to
# `TO.tensoralloc`.
function TO.tensoralloc_contract(
        TC,
        A::AbstractArray, pA::Index2Tuple, conjA::Bool,
        B::AbstractArray, pB::Index2Tuple, conjB::Bool,
        pAB::Index2Tuple,
        istemp::Val, allocator::TO.CUDAAllocator
    )
    CT = CuArray{TC, TO.numind(pAB)}
    dest = TO.tensoralloc(
        CT, TO.tensorcontract_structure(A, pA, conjA, B, pB, conjB, pAB), istemp, allocator
    )
    # The assertion pins the return type to the requested `CuArray{TC, N}`.
    return dest::CT
end
# NOTE: the general implementation in the `DefaultAllocator` case works just fine, without
# selecting an explicit memory model
# Construct an uninitialised `CuArray{T, N}` for `structure`, picking the memory-kind type
# parameter from the allocator: `Mtemp` when `istemp` is `true`, `Mout` otherwise.
# `Min` is part of the allocator's type but is not consulted here.
function TO.tensoralloc(
        ::Type{CuArray{T, N}}, structure,
        ::Val{istemp}, allocator::TO.CUDAAllocator{Mout, Min, Mtemp}
    ) where {T, N, istemp, Mout, Min, Mtemp}
    # `istemp` is a static parameter, so this branch is decided per method specialisation.
    if istemp
        return CuArray{T, N, Mtemp}(undef, structure)
    else
        return CuArray{T, N, Mout}(undef, structure)
    end
end
# Hand `C` to `CUDACore.unsafe_free!` and return `nothing`. The allocator argument is used
# for dispatch only.
# NOTE(review): presumably `C` must not be used after this call -- confirm against the
# `unsafe_free!` documentation.
TO.tensorfree!(C::CuArray, ::TO.CUDAAllocator) = (CUDACore.unsafe_free!(C); nothing)
end