-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathallocators.jl
More file actions
73 lines (57 loc) · 3.33 KB
/
allocators.jl
File metadata and controls
73 lines (57 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
##
const ZEROS_DOC = """
    @zeros(args...)

Allocate an array filled with zeros, equivalent to `zeros(numbertype, args...)`, where
`numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the
allocation is chosen to be compatible with the package for parallelization selected with
[`@init_parallel_kernel`](@ref): `CUDA.zeros` for CUDA, and for Threads a new `Array`
whose elements are set to zero in a multithreaded loop (`Threads.@threads`).
"""
# The docstring lives in a constant so it can be attached with `@doc` below.
@doc ZEROS_DOC
macro zeros(args...) check_initialized(); esc(_zeros(args...)); end
##
const ONES_DOC = """
    @ones(args...)

Allocate an array filled with ones, equivalent to `ones(numbertype, args...)`, where
`numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and the
allocation is chosen to be compatible with the package for parallelization selected with
[`@init_parallel_kernel`](@ref): `CUDA.ones` for CUDA, and for Threads a new `Array`
whose elements are set to one in a multithreaded loop (`Threads.@threads`).
"""
# The docstring lives in a constant so it can be attached with `@doc` below.
@doc ONES_DOC
macro ones(args...) check_initialized(); esc(_ones(args...)); end
##
const RAND_DOC = """
    @rand(args...)

Allocate an array of random values, equivalent to `rand(numbertype, args...)`, where
`numbertype` is the datatype selected with [`@init_parallel_kernel`](@ref) and `rand` is
chosen/implemented to be compatible with the package for parallelization selected with
[`@init_parallel_kernel`](@ref): for CUDA, the values are generated on the host and then
copied into a `CUDA.CuArray`; for Threads, a new `Array` is filled in a multithreaded loop
(`Threads.@threads`).
"""
# The docstring lives in a constant so it can be attached with `@doc` below.
@doc RAND_DOC
macro rand(args...) check_initialized(); esc(_rand(args...)); end
## MACROS FORCING A PACKAGE (OVERRIDING THE ONE SELECTED AT INITIALIZATION)
# Each macro targets the backend named by its suffix, regardless of the package chosen
# with `@init_parallel_kernel`. Like `@zeros`/`@ones`/`@rand`, each one still calls
# `check_initialized()` before building the allocation expression.

# -- CUDA backend --
macro zeros_cuda(args...)
    check_initialized()
    return esc(_zeros(args...; package=PKG_CUDA))
end

macro ones_cuda(args...)
    check_initialized()
    return esc(_ones(args...; package=PKG_CUDA))
end

macro rand_cuda(args...)
    check_initialized()
    return esc(_rand(args...; package=PKG_CUDA))
end

# -- Threads backend --
macro zeros_threads(args...)
    check_initialized()
    return esc(_zeros(args...; package=PKG_THREADS))
end

macro ones_threads(args...)
    check_initialized()
    return esc(_ones(args...; package=PKG_THREADS))
end

macro rand_threads(args...)
    check_initialized()
    return esc(_rand(args...; package=PKG_THREADS))
end
## ALLOCATOR FUNCTIONS
"""
    _zeros(args...; package=get_package())

Return an unevaluated expression that allocates a zero-filled array of the configured
number type (`get_numbertype()`) with dimensions `args`:

- `PKG_CUDA`: `CUDA.zeros(numbertype, args...)`;
- `PKG_THREADS`: `ParallelStencil.ParallelKernel._parallel_init(Base.zero, numbertype, args...)`.

Any other `package` is reported through `@KeywordArgumentError` (defined elsewhere;
presumably it throws), whose result is returned.
"""
function _zeros(args...; package::Symbol=get_package())
    T = get_numbertype()
    if package == PKG_CUDA
        return :(CUDA.zeros($T, $(args...)))
    elseif package == PKG_THREADS
        return :(ParallelStencil.ParallelKernel._parallel_init(Base.zero, $T, $(args...)))
    end
    # Unsupported backend: fall through to the error report.
    @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
"""
    _ones(args...; package=get_package())

Build the unevaluated allocation expression behind `@ones`: an array of the configured
number type with dimensions `args` and every element equal to one. `PKG_CUDA` maps to
`CUDA.ones`, and `PKG_THREADS` maps to `_parallel_init` with `Base.one`. Any other
package falls through to `@KeywordArgumentError` (defined elsewhere; presumably throws).
"""
function _ones(args...; package::Symbol=get_package())
    T = get_numbertype()
    package == PKG_CUDA &&
        return :(CUDA.ones($T, $(args...)))
    package == PKG_THREADS &&
        return :(ParallelStencil.ParallelKernel._parallel_init(Base.one, $T, $(args...)))
    @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
end
"""
    _rand(args...; package=get_package())

Return an unevaluated expression that allocates an array of random values of the
configured number type (`get_numbertype()`) with dimensions `args`:

- `PKG_CUDA`: values are drawn on the host with `Base.rand` and copied to the device
  with `CUDA.CuArray`;
- `PKG_THREADS`: `ParallelStencil.ParallelKernel._parallel_init(Base.rand, numbertype, args...)`.

Any other `package` is reported through `@KeywordArgumentError` (defined elsewhere;
presumably it throws).
"""
function _rand(args...; package::Symbol=get_package())
    numbertype = get_numbertype()
    if (package == PKG_CUDA)
        # The macros `esc` this expression into the caller's scope, where a plain `rand`
        # could be shadowed; qualify it as `Base.rand`, as the Threads branch already does.
        return :(CUDA.CuArray(Base.rand($numbertype, $(args...))))
    elseif (package == PKG_THREADS)
        return :(ParallelStencil.ParallelKernel._parallel_init(Base.rand, $numbertype, $(args...)))
    else
        @KeywordArgumentError("$ERRMSG_UNSUPPORTED_PACKAGE (obtained: $package).")
    end
end
"""
    _parallel_init(f, numbertype, dims...)

Allocate an uninitialized `Array{numbertype}` of size `dims` and set every element to
`f(numbertype)`. The fill loop is split across threads with `Threads.@threads :static`.

`dims` may be given either as separate integers (`3, 4`) or as a single tuple (`(3, 4)`),
mirroring `zeros(T, dims...)`. This is the Threads backend used by the allocation
expressions for `@zeros` (`f = Base.zero`), `@ones` (`Base.one`) and `@rand` (`Base.rand`).

Returns the filled `Array`.
"""
function _parallel_init(f::F, numbertype::Type{T}, args...) where {F, T}
    # `Array{T}(undef, ...)` infers the rank from its arguments, so it accepts both
    # `3, 4` and `(3, 4)`; a hard-coded rank of `length(args)` would reject the tuple form.
    arr = Array{T}(undef, args...)
    # NOTE(review): `:static` scheduling errors if this is called from inside another
    # `@threads` region; presumably it is chosen so each thread first-touches a fixed
    # chunk of `arr` — confirm before changing the schedule.
    Threads.@threads :static for i in eachindex(arr)
        @inbounds arr[i] = f(T)
    end
    return arr
end