-
Notifications
You must be signed in to change notification settings - Fork 97
Expand file tree
/
Copy pathChainRulesKernelAbstractionsExt.jl
More file actions
45 lines (37 loc) · 1.24 KB
/
ChainRulesKernelAbstractionsExt.jl
File metadata and controls
45 lines (37 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
module ChainRulesKernelAbstractionsExt
import Adapt
import Atomix
import ChainRules
import GPUArrays
import KernelAbstractions as KA
using GPUArraysCore: AbstractGPUArray
using KernelAbstractions
"""
    ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)

Accumulate the cotangent `dy` into the region of `dx` selected by `inds`, in place,
and return `dx`.

If the backend of `dx` reports `KA.supports_atomics`, the update runs on the device:
the `scatter!` kernel is launched with `ndrange = length(dy)` (one work item per
element of `dy`), and each work item adds its element into `dx` with an atomic `+`
through Atomix, so work items that hit the same element of `dx` do not race.

Otherwise `dx`, `inds` and `dy` are converted with `Adapt.adapt(Array, …)`, the update
is done on the host with a broadcasted `.+=` into a view, and the result is copied back
with `copyto!`. Note that this fallback transfers all of `dx`, not only the indexed
region.

NOTE(review): the device path indexes with `@inbounds`, so it assumes `length(dy)`
equals the number of elements selected by `inds` — confirm callers guarantee this.
"""
function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
    kab = get_backend(dx)
    if KA.supports_atomics(kab)
        gids = GPUArrays.to_indices(dx, inds)
        idims = map(length, gids)
        # Move the index collections to the device that owns `dx`: the kernel is
        # launched on `dx`'s backend and writes into `dx`. Unlike the untyped `dy`,
        # `dx` is guaranteed by dispatch to be an `AbstractGPUArray`.
        # NOTE(review): `ToGPU` is expected to wrap a GPU array — confirm in GPUArrays.
        Is = map(Adapt.adapt(GPUArrays.ToGPU(dx)), gids)
        scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
    else
        # Host round-trip for backends without atomic support.
        dx_cpu = Adapt.adapt(Array, dx)
        view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy)
        copyto!(dx, dx_cpu)
    end
    return dx
end
# KernelAbstractions kernel behind `∇getindex!`'s device path.
# Each work item is identified by its global linear index, which is used as an
# index into `src`, and forwards all the real work to `_scatter!`.
# Launched as `scatter!(backend)(op, dest, src, idims, Is...; ndrange=length(src))`.
@kernel function scatter!(op, dest, src, idims, Is::Vararg{Any, N}) where N
    item = @index(Global, Linear)
    _scatter!(item, op, dest, src, idims, Is...)
end
# Device-side body of the `scatter!` kernel for a single work item `i`.
#
# `i` is used as a linear index into `src`, and is also mapped to a Cartesian
# position `is` over the indexed region of extent `idims`. For each of the `N`
# dimensions, the j-th component of that position selects `I_j = Is[j][is[j]]`
# from the j-th index collection. Together these give the destination location
# in `dest`. The source value `src[i]` is then combined into
# `dest[I_1, ..., I_N]` by `_accum!` using `op`.
#
# `@generated` lets `$N` be spliced in as a literal. `Base.Cartesian.@nexprs`
# and `@ncall` require that literal to unroll the per-dimension lookups into the
# locals `I_1, ..., I_N` and a single `_accum!` call.
@generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N
    quote
        # The CartesianIndices lookup skips bounds checks: this assumes
        # `i <= prod(idims)`, i.e. that `src` has as many elements as the
        # indexed region — TODO confirm against callers.
        is = @inbounds CartesianIndices(idims)[i]
        dv = src[i]
        # Expands to: I_j = Is[j][is[j]]  for j = 1:N
        Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]])
        # Expands to: _accum!(op, dest, dv, I_1, ..., I_N)
        Base.Cartesian.@ncall $N _accum! op dest dv j -> I_j
    end
end
# Combine `val` into `dest[ids...]` with `op` as one atomic read-modify-write
# (via Atomix), so concurrent work items updating the same element of `dest`
# do not lose updates. The varargs `ids` are already a tuple and are used
# directly as the index. Returns whatever `Atomix.modify!` returns.
_accum!(op, dest, val, ids...) =
    Atomix.modify!(Atomix.IndexableRef(dest, ids), op, val)
end