-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathStridedGPUArraysExt.jl
More file actions
137 lines (109 loc) · 4.59 KB
/
StridedGPUArraysExt.jl
File metadata and controls
137 lines (109 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
module StridedGPUArraysExt
using Strided, GPUArrays, LinearAlgebra
import Strided: _gemm!
using GPUArrays: Adapt, KernelAbstractions
using GPUArrays.KernelAbstractions: @kernel, @index
using StridedViews: ParentIndex
# Union of the types of the functions `adjoint`, `conj`, `identity` and `transpose`.
# NOTE(review): presumably the set of element-wise ops a `StridedView` can carry in its
# `op` field — confirm against StridedViews.
# Declared `const` so the binding is a typed constant rather than a non-constant global.
const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)}
# Alias for a `StridedView` whose parent storage is any GPU array type. The parent's
# element type is tied to the view's element type `T`.
const GPUStridedView{T, N} = (StridedView{T, N, A} where {A <: AnyGPUArray{T}})
# The KernelAbstractions backend of a GPU-backed strided view is the backend of the
# array it wraps.
function KernelAbstractions.get_backend(sv::GPUStridedView)
    storage = parent(sv)
    return KernelAbstractions.get_backend(storage)
end
# Convert a GPU-backed strided view to a host `Array`.
#
# The view is first copied into a fresh array allocated with `similar` on its parent
# (same element type and size as the view), so the strided copy happens between two
# GPU-backed views. The final transfer is then delegated to `Array` on that array.
function Base.Array(a::GPUStridedView)
    elt = eltype(a)
    shape = size(a)
    dense = similar(parent(a), elt, shape)
    copy!(StridedView(dense), a)
    return Array(dense)
end
# Matrix multiplication for GPU-backed strided views.
#
# `opA` / `opB` are the operation characters handed to `LinearAlgebra.wrap`, which wraps
# `A` and `B` accordingly. The product is then computed by GPUArrays' generic kernel
# with scaling factors `α` and `β`, writing into `C`.
function Strided._gemm!(
        opA::Char, opB::Char, α, A::TA, B::TB, β, C::TC
    ) where {TA <: GPUStridedView, TB <: GPUStridedView, TC <: GPUStridedView}
    lhs = LinearAlgebra.wrap(A, opA)
    rhs = LinearAlgebra.wrap(B, opB)
    return GPUArrays.generic_matmatmul!(C, lhs, rhs, α, β)
end
# ---------- GPU mapreduce support ----------
# Seed value for the accumulator: without an `initop` the current value is used
# unchanged, otherwise `initop` is applied to it.
@inline function _gpu_init_acc(::Nothing, current_val)
    return current_val
end
@inline function _gpu_init_acc(initop, current_val)
    return initop(current_val)
end
# Accumulation step: without a reduction operator the new value replaces the
# accumulator, otherwise the two are combined as `op(acc, val)`.
@inline function _gpu_accum(::Nothing, acc, val)
    return val
end
@inline function _gpu_accum(op, acc, val)
    return op(acc, val)
end
# Zero-based linear offset of the Cartesian index `cidx` under the given `strides`:
# the sum over all dimensions `d` of `strides[d] * (cidx[d] - 1)`.
# Returns `0` when `N == 0`.
@inline function cartesian2parent(strides::NTuple{N, Int}, cidx::CartesianIndex{N}) where {N}
    # One stride-weighted term per dimension, built as a tuple of length `N`.
    terms = ntuple(Val(N)) do d
        @inbounds strides[d] * (cidx[d] - 1)
    end
    # The leading `0` keeps the empty (`N == 0`) case well defined.
    return +(0, terms...)
end
# KernelAbstractions kernel for a strided map-reduce. Each work item handles one
# Cartesian index of the launch range, reduces over `dims_red`, and writes a single
# output element.
#
# Arguments:
# - `f`: map function, called on the op-transformed values read from the inputs.
# - `op`: reduction operator combining the accumulator with `f(...)`; when it is
#   `nothing`, `_gpu_accum` replaces the accumulator with the new value instead.
# - `initop`: applied to the current output value to seed the accumulator; when it is
#   `nothing`, the current output value is used unchanged.
# - `dims_red`: extents iterated by the inner reduction loop.
# - `strides`, `ops`, `arrays`: per-array tuples; entry 1 belongs to the output, the
#   remaining entries to the inputs.
# - `offsets`: added element-wise to the stride-derived indices.
#   NOTE(review): presumably the zero-based parent offsets of the views, and `arrays`
#   the parent arrays — confirm against the caller.
@kernel function _mapreduce_gpu_kernel!(
f, op, initop,
dims_red, strides, offsets, ops, arrays
)
# Cartesian index of this work item within the launch range.
I_out = @index(Global, Cartesian)
# Compute parent index for current index.
# `cartesian2parent` yields a zero-based offset per array; adding the array's offset
# and 1 gives the index used to address each array below.
Is_parent = cartesian2parent.(strides, Ref(I_out)) .+ offsets .+ 1
# Initialize accumulator from current output value (or apply initop)
out = arrays[1]
out_I_parent = Is_parent[1]
# The output's op is applied when reading the stored value.
@inbounds acc = _gpu_init_acc(initop, ops[1](out[out_I_parent]))
# Everything after the first tuple entry describes the inputs.
inputs = Base.tail(arrays)
inputs_I_parent = Base.tail(Is_parent)
inputs_strides = Base.tail(strides)
inputs_ops = Base.tail(ops)
for I_red in CartesianIndices(dims_red)
# Compute parent index for current reduction index
# (a zero-based offset that is added to each input's base index).
Is_red_parent = cartesian2parent.(inputs_strides, Ref(I_red))
Is_inputs = inputs_I_parent .+ Is_red_parent
# Get values from each input array, apply map function, and accumulate
# Note: these reads are not marked `@inbounds`, unlike the output accesses, and the
# do-block argument `in` shadows `Base.in` inside the closure.
vals = map(inputs, inputs_ops, Is_inputs) do in, in_op, in_I
in_op(getindex(in, in_I))
end
acc = _gpu_accum(op, acc, f(vals...))
end
# Write back result to output array
# (the output's op is applied again before storing).
@inbounds out[out_I_parent] = ops[1](acc)
end
# Full mapreduce over a GPU-backed strided view, written to avoid scalar indexing of
# device memory (e.g. `first(A)` or `out[ParentIndex(1)]`), which JLArrays and real
# GPUs prohibit.
#
# - Empty input: returns `Base.mapreduce_empty(f, op, T)`, combined with `nt.init` via
#   `op` when an initial value was supplied.
# - With `nt`: the output element type is `typeof(nt.init)` and the buffer is seeded
#   with `nt.init`.
# - Without `nt`: the element type is inferred through the Broadcast machinery and
#   `promote_op`; the buffer is seeded with `GPUArrays.neutral_element(op, ET)`.
#   An error is raised if the type cannot be inferred.
#
# The reduction runs into a one-element buffer allocated next to `parent(A)`, and the
# scalar result is read back through `Array`.
function Strided._mapreduce(
        f, op, A::GPUStridedView{T, N}, nt = nothing
    ) where {T, N}
    if isempty(A)
        empty_val = Base.mapreduce_empty(f, op, T)
        return isnothing(nt) ? empty_val : op(empty_val, nt.init)
    end

    if !isnothing(nt)
        # Explicit initial value: it fixes both the seed and the output element type.
        seed = nt.init
        ET = typeof(seed)
    else
        mapped = Base.Broadcast.combine_eltypes(f, (A,))
        ET = Base.promote_op(op, mapped, mapped)
        if ET === Union{} || ET === Any
            error("cannot infer output element type for mapreduce; pass an explicit `init`")
        end
        seed = GPUArrays.neutral_element(op, ET)
    end

    extents = size(A)
    result = similar(parent(A), ET, (1,))
    fill!(result, seed)
    # View the one-element buffer with the same number of dimensions as `A`, all of
    # extent one, and reduce `A` into it.
    target = sreshape(StridedView(result), one.(extents))
    Strided._mapreducedim!(f, op, nothing, extents, (target, A))
    return Array(result)[1]
end
# GPU implementation of the blocked map-reduce step: launches `_mapreduce_gpu_kernel!`
# on the backend of the output's parent array.
#
# `arrays[1]` is the output view and `strides[1]` its strides. A dimension whose output
# stride is zero is treated as a reduction dimension: it has extent 1 in the launch
# range and its full extent in the reduction range. All other dimensions are the
# opposite. `costs` is accepted but not used here. Returns `nothing`.
function Strided._mapreduce_block!(
        f, op, initop,
        dims::Dims{N},
        strides, offsets, costs,
        arrays::Tuple{GPUStridedView{TO, N}, Vararg{GPUStridedView{<:Any, N}}}
    ) where {TO, N}
    dest = first(arrays)
    dest_strides = strides[1]

    # One work item per output element: reduced dimensions collapse to extent 1.
    launch_range = ntuple(d -> (@inbounds iszero(dest_strides[d]) ? 1 : dims[d]), Val(N))
    # Extents looped over inside each work item.
    reduction_range = ntuple(d -> (@inbounds iszero(dest_strides[d]) ? dims[d] : 1), Val(N))

    # The kernel receives each view's element-wise op and parent array, not the views.
    view_ops = map(sv -> sv.op, arrays)
    parents = map(parent, arrays)

    launch! = _mapreduce_gpu_kernel!(KernelAbstractions.get_backend(parent(dest)))
    launch!(
        f, op, initop, reduction_range, strides, offsets, view_ops, parents;
        ndrange = launch_range
    )
    return nothing
end
end