1 change: 1 addition & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.9.34"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14 changes: 14 additions & 0 deletions ext/NNlibCUDAExt/NNlibCUDAExt.jl
@@ -2,8 +2,22 @@ module NNlibCUDAExt

using NNlib
using CUDA
import CUDA.CUSPARSE:
AbstractCuSparseVector,
CuSparseMatrixCSC,
CuSparseMatrixCSR,
CuSparseMatrixBSR,
CuSparseMatrixCOO
using Random, Statistics

const AbstractCuSparseArray{Tv,Ti} = Union{
AbstractCuSparseVector{Tv,Ti},
CuSparseMatrixCSC{Tv,Ti},
CuSparseMatrixCSR{Tv,Ti},
CuSparseMatrixBSR{Tv,Ti},
CuSparseMatrixCOO{Tv,Ti},
}

include("sampling.jl")
include("activations.jl")
include("batchedadjtrans.jl")
27 changes: 8 additions & 19 deletions ext/NNlibCUDAExt/scatter.jl
@@ -1,16 +1,5 @@
# supported op: +, -, *, /, max, min, &, |, mean

## TODO support sparse dst/src/idx
## See issue https://github.com/FluxML/NNlib.jl/issues/647
# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector
# const AnyCuSparseMatrix{Tv,Ti} = Union{
# AbstractCuSparseMatrix{Tv,Ti},
# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix
# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX
# CUDA.CuSparseMatrixCOO{Tv,Ti},
# }
# const AnyCuSparseArray{Tv,Ti} = Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}}

function scatter_kernel!(op::OP, dst, src, idx) where OP
index = threadIdx().x + (blockIdx().x - 1) * blockDim().x

@@ -56,9 +45,8 @@ end


function NNlib.scatter!(op::OP, dst::AnyCuArray,
src::AnyCuArray,
idx::AnyCuArray) where OP
isempty(idx) && return dst
src::Union{AnyCuArray,AbstractCuSparseArray},
idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP
dims = NNlib.scatter_dims(dst, src, idx)
args = if dims == 0
max_idx = length(idx)
@@ -79,8 +67,8 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray,
end

function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray,
src::AnyCuArray,
idx::AnyCuArray)
src::Union{AnyCuArray,AbstractCuSparseArray},
idx::Union{AnyCuArray,AbstractCuSparseArray})
Ns = NNlib.scatter!(+, zero(dst), one.(src), idx)
dst_ = NNlib.scatter!(+, zero(dst), src, idx)
dst .+= NNlib.safe_div.(dst_, Ns)
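
A hedged usage sketch of the broadened scatter! methods above, mirroring the test fixtures added in this PR (assumes a functional CUDA device; the values are illustrative):

using CUDA, SparseArrays, Statistics
using NNlib

dst = cu(Float32[3, 4, 5, 6, 7])                   # dense destination (dims == 0 case)
src = sparse(cu(ones(Float32, 3, 4)), fmt = :csr)  # sparse source, same shape as idx
idx = cu([1 2 3 4;
          4 2 1 3;
          3 5 5 3])                                # indices into dst

NNlib.scatter!(+, copy(dst), src, idx)     # accumulate src into dst at idx
NNlib.scatter!(mean, copy(dst), src, idx)  # add the per-index mean of src instead
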
@@ -177,10 +165,11 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca
end

function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst,
src::AnyCuArray,
idx::AnyCuArray)
src::Union{AnyCuArray,AbstractCuSparseArray},
idx::Union{AnyCuArray,AbstractCuSparseArray})
dims = ndims(src) - ndims(idx)
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src)
gathered_src = CuArray(src) # Convert to dense to avoid type-unstable broadcast if sparse
Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), gathered_src)
rev_idx = NNlib.reverse_indices(idx)
rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx))
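
The CuArray(src) call above densifies a sparse source before modify_src broadcasts over it, so the broadcast stays on one concrete dense type. A minimal sketch of that conversion (assumes a CUDA device):

using CUDA, SparseArrays

S = sparse(cu(ones(Float32, 3, 4)), fmt = :csc)  # CuSparseMatrixCSC
D = CuArray(S)                                   # dense CuMatrix copy of the sparse data
D .* 2f0                                         # ordinary dense broadcast
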

2 changes: 2 additions & 0 deletions src/NNlib.jl
@@ -7,6 +7,8 @@ using Base.Broadcast: broadcasted
using Base.Threads
using ChainRulesCore
using GPUArraysCore
using GPUArrays
using GPUArrays: AbstractGPUSparseArray
using KernelAbstractions
using KernelAbstractions: @atomic
using LinearAlgebra
2 changes: 1 addition & 1 deletion src/gather.jl
@@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
return dst
end

function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray)
function gather!(dst::AnyGPUArray, src::Union{AnyGPUArray, AbstractGPUSparseArray}, idx::AnyGPUArray)
isempty(dst) && return dst
n_dims = scatter_dims(src, dst, idx)
dims = size(src)[1:n_dims]
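
For reference, gather!(dst, src, idx) copies the slices of src selected by idx into dst; the change above only broadens the GPU method so src may also be a sparse GPU array. A small CPU illustration of the contract (plain NNlib, no GPU required):

using NNlib

src = Float32[1 0 3;
              0 5 0]
idx = [1, 3, 2]                   # pick columns 1, 3 and 2 of src
dst = zeros(Float32, 2, 3)

NNlib.gather!(dst, src, idx)
# dst == Float32[1 3 0;
#                0 0 5]
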
16 changes: 15 additions & 1 deletion test/ext_cuda/runtests.jl
@@ -4,10 +4,24 @@ using Zygote
using ForwardDiff: Dual
using Statistics: mean
using CUDA, cuDNN
import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO
using SparseArrays
import CUDA.CUSPARSE:
AbstractCuSparseVector,
CuSparseMatrixCSC,
CuSparseMatrixCSR,
CuSparseMatrixBSR,
CuSparseMatrixCOO
using NNlib: batchnorm, ∇batchnorm
CUDA.allowscalar(false)

const AbstractCuSparseArray{Tv,Ti} = Union{
AbstractCuSparseVector{Tv,Ti},
CuSparseMatrixCSC{Tv,Ti},
CuSparseMatrixCSR{Tv,Ti},
CuSparseMatrixBSR{Tv,Ti},
CuSparseMatrixCOO{Tv,Ti},
}

include("test_utils.jl")
include("activations.jl")
include("dropout.jl")
125 changes: 95 additions & 30 deletions test/ext_cuda/scatter.jl
@@ -1,104 +1,169 @@
dsts = Dict(
0 => cu([3, 4, 5, 6, 7]),
1 => cu([3 3 4 4 5;
5 5 6 6 7]),
5 5 6 6 7]),
)
srcs = Dict(
(0, true) => cu(ones(Int, 3, 4)),
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
(1, true) => cu(ones(Int, 2, 3, 4)),
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1, 3, 4)),
)
idxs = [
cu([1 2 3 4;
4 2 1 3;
3 5 5 3]), # integer index
cu([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]), # tuple index
cu(CartesianIndex.([(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)])), # CartesianIndex index
(3,) (5,) (5,) (3,)
]), # tuple index
cu(CartesianIndex.([
(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)
])), # CartesianIndex index
]

types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}]


@testset "scatter" begin
for T = types
for T in types
@testset "$(T)" begin
@testset "+" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end

@testset "-" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end

@testset "max" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end

@testset "min" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end
end
end


# Tests for the sparse-specialized scatter kernels. Duplicated because the test cases above do not cover sparse arrays.
srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt)
num_types = [Float32, Float64]

@testset "scatter sparse-specialized" begin
for T in num_types
@testset "+" begin
# Dims is implicitly 0. No sparse equivalent for multidimensional src
for idx in idxs, fmt in [:csc, :csr, :bsr, :coo]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

@testset "-" begin
for idx in idxs, fmt in [:csc, :csr, :bsr, :coo]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(-, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

@testset "max" begin
for idx in idxs, fmt in [:csc, :csr]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(max, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

@testset "min" begin
for idx in idxs, fmt in [:csc, :csr]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(min, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

# Sparse-specialized for operations not tested on eltype <: Integer
@testset "$(T)" begin
# Dims is implicitly 0. No sparse equivalent for multidimensional src/dst
@testset "*" begin
for idx in idxs, fmt in [:csc, :csr, :bsr, :coo]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(*, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

@testset "/" begin
for idx in idxs, fmt in [:csc, :csr]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(/, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end

@testset "mean" begin
for idx in idxs, fmt in [:csc, :csr]
# mutated = true
gputest((dst, src, idx) -> NNlib.scatter!(mean, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx)
end
end
end
end
end


for T = [CuArray{Float32}, CuArray{Float64}]
for T in [CuArray{Float32}, CuArray{Float64}]
@testset "$(T)" begin
@testset "*" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end

@testset "/" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end

@testset "mean" begin
for idx = idxs, dims = [0, 1]
for idx in idxs, dims in [0, 1]
mutated = true
gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true)
gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true)

mutated = false
gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true)
gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad = true)
end
end
end
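
For context, the srcs_sp fixture used above builds a 3×4 all-ones CUSPARSE matrix in the requested storage format; roughly (assumes a CUDA device):

using CUDA, SparseArrays

srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt)

S = srcs_sp(Float32, :csr)  # CuSparseMatrixCSR holding 12 stored ones
size(S), nnz(S)             # ((3, 4), 12)
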
23 changes: 23 additions & 0 deletions test/ext_cuda/test_utils.jl
@@ -20,3 +20,26 @@ function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false,
end
end
end

function gputest(f, xs::Vararg{Union{AnyCuArray,AbstractCuSparseArray}}; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...)
gpu_in = xs
# Reference inputs: densify everything to plain CuArrays (sparse inputs become dense),
# so `f` on the dense path provides the expected values for the sparse path.
cpu_in = map(CuArray, xs)

cpu_out = f(cpu_in...; kws...)
gpu_out = f(gpu_in...; kws...) |> collect
@test collect(cpu_out) ≈ gpu_out rtol=rtol atol=atol broken=broken

if checkgrad
# use mean instead of sum to prevent error accumulation (for larger
# tensors) which causes error to go above atol
cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...)
gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...)
for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
if cpu_g === nothing
@test gpu_g === nothing
else
@test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad
end
end
end
end
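
The method added above compares f applied to the given (possibly sparse) GPU inputs against f applied to densified CuArray copies of the same inputs. A hedged sketch of how the sparse scatter testset above invokes it, with gputest and the CUSPARSE imports from the files above already in scope (assumes a CUDA device):

using CUDA, SparseArrays

gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx),
        cu(Float32[3, 4, 5, 6, 7]),                   # dense destination
        sparse(cu(ones(Float32, 3, 4)), fmt = :csr),  # sparse source
        cu([1 2 3 4; 4 2 1 3; 3 5 5 3]))              # index array
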
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -155,7 +155,7 @@ end
using CUDA
if CUDA.functional()
@testset "CUDA" begin
nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather")))
nnlib_testsuite(CUDABackend; skip_tests=Set(("Gather",)))

include("ext_cuda/runtests.jl")
end