diff --git a/Project.toml b/Project.toml index 1f02c3a0..dd487b92 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.34" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/NNlibCUDAExt/NNlibCUDAExt.jl b/ext/NNlibCUDAExt/NNlibCUDAExt.jl index 7939b808..338d5605 100644 --- a/ext/NNlibCUDAExt/NNlibCUDAExt.jl +++ b/ext/NNlibCUDAExt/NNlibCUDAExt.jl @@ -2,8 +2,22 @@ module NNlibCUDAExt using NNlib using CUDA +import CUDA.CUSPARSE: + AbstractCuSparseVector, + CuSparseMatrixCSC, + CuSparseMatrixCSR, + CuSparseMatrixBSR, + CuSparseMatrixCOO using Random, Statistics +const AbstractCuSparseArray{Tv,Ti} = Union{ + AbstractCuSparseVector{Tv,Ti}, + CuSparseMatrixCSC{Tv,Ti}, + CuSparseMatrixCSR{Tv,Ti}, + CuSparseMatrixBSR{Tv,Ti}, + CuSparseMatrixCOO{Tv,Ti}, +} + include("sampling.jl") include("activations.jl") include("batchedadjtrans.jl") diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 04a63805..0b6dca0e 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,16 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean -## TODO support sparse dst/src/idx -## See issue https://github.com/FluxML/NNlib.jl/issues/647 -# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector -# const AnyCuSparseMatrix{Tv,Ti} = Union{ -# AbstractCuSparseMatrix{Tv,Ti}, -# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix -# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX -# CUDA.CuSparseMatrixCOO{Tv,Ti}, -# } -# const AnyCuSparseArray{Tv,Ti} = 
Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}} - function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -56,9 +45,9 @@ end function NNlib.scatter!(op::OP, dst::AnyCuArray, - src::AnyCuArray, - idx::AnyCuArray) where OP + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP isempty(idx) && return dst dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -79,8 +68,8 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, end function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, - src::AnyCuArray, - idx::AnyCuArray) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -177,10 +166,11 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray, - idx::AnyCuArray) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) dims = ndims(src) - ndims(idx) - Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) + dense_src = CuArray(src) # Convert to dense to avoid type-unstable broadcast if sparse + Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), dense_src) rev_idx = NNlib.reverse_indices(idx) rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) diff --git a/src/NNlib.jl b/src/NNlib.jl index 5ea90778..41b41560 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,6 +7,8 @@ using Base.Broadcast: broadcasted using Base.Threads using ChainRulesCore using GPUArraysCore +using GPUArrays +using GPUArrays: AbstractGPUSparseArray using KernelAbstractions using KernelAbstractions: @atomic using LinearAlgebra diff --git a/src/gather.jl b/src/gather.jl index 7997f878..8d9e20a0 100644 --- a/src/gather.jl +++
b/src/gather.jl @@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) return dst end -function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) +function gather!(dst::AnyGPUArray, src::Union{AnyGPUArray, AbstractGPUSparseArray}, idx::AnyGPUArray) isempty(dst) && return dst n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims] diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 1db0b6ee..e07c4185 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -4,10 +4,24 @@ using Zygote using ForwardDiff: Dual using Statistics: mean using CUDA, cuDNN -import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO +using SparseArrays +import CUDA.CUSPARSE: + AbstractCuSparseVector, + CuSparseMatrixCSC, + CuSparseMatrixCSR, + CuSparseMatrixBSR, + CuSparseMatrixCOO using NNlib: batchnorm, ∇batchnorm CUDA.allowscalar(false) +const AbstractCuSparseArray{Tv,Ti} = Union{ + AbstractCuSparseVector{Tv,Ti}, + CuSparseMatrixCSC{Tv,Ti}, + CuSparseMatrixCSR{Tv,Ti}, + CuSparseMatrixBSR{Tv,Ti}, + CuSparseMatrixCOO{Tv,Ti}, +} + include("test_utils.jl") include("activations.jl") include("dropout.jl") diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index a4977f28..f7070162 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -1,13 +1,13 @@ dsts = Dict( 0 => cu([3, 4, 5, 6, 7]), 1 => cu([3 3 4 4 5; - 5 5 6 6 7]), + 5 5 6 6 7]), ) srcs = Dict( (0, true) => cu(ones(Int, 3, 4)), (0, false) => cu(ones(Int, 3) * collect(1:4)'), (1, true) => cu(ones(Int, 2, 3, 4)), - (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1, 3, 4)), ) idxs = [ cu([1 2 3 4; @@ -15,90 +15,155 @@ idxs = [ 3 5 5 3]), # integer index cu([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)]), # tuple index - cu(CartesianIndex.([(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - 
(3,) (5,) (5,) (3,)])), # CartesianIndex index + (3,) (5,) (5,) (3,) + ]), # tuple index + cu(CartesianIndex.([ + (1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,) + ])), # CartesianIndex index ] types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] - @testset "scatter" begin - for T = types + for T in types @testset "$(T)" begin @testset "+" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "-" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "max" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad = 
true) end end @testset "min" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) + end + end + end + end + + + # Specialized sparse scatter kernels. Duplicated because the test cases above do not cover sparse arrays. + srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt) + num_types = [Float32, Float64] + + @testset "scatter sparse-specialized" begin + for T in num_types + @testset "+" begin + # Dims is implicitly 0. No sparse equivalent for multidimensional src + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "-" begin + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(-, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "max" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(max, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "min" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(min, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + # Sparse-specialized for operations not tested on eltype <: Integer + @testset "$(T)" begin + # Dims is implicitly 0.
No sparse equivalent for multidimensional src/dst + @testset "*" begin + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(*, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "/" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(/, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "mean" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(mean, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end end end end end - for T = [CuArray{Float32}, CuArray{Float64}] + for T in [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "/" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "mean" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) ->
NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end end diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl index 18f307e0..1cceb563 100644 --- a/test/ext_cuda/test_utils.jl +++ b/test/ext_cuda/test_utils.jl @@ -20,3 +20,26 @@ function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, end end end + +function gputest(f, xs::Vararg{Union{AnyCuArray,AbstractCuSparseArray}}; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) + gpu_in = xs + cpu_in = map(CuArray, xs) + + cpu_out = f(cpu_in...; kws...) + gpu_out = f(gpu_in...; kws...) |> collect + @test collect(cpu_out) ≈ gpu_out rtol=rtol atol=atol broken=broken + + if checkgrad + # use mean instead of sum to prevent error accumulation (for larger + # tensors) which causes error to go above atol + cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...) + gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...) + for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad) + if cpu_g === nothing + @test gpu_g === nothing + else + @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad + end + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index baa19a1c..0c5f22c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,7 +155,7 @@ end using CUDA if CUDA.functional() @testset "CUDA" begin - nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather"))) + nnlib_testsuite(CUDABackend; skip_tests=Set(("Gather",))) #= trailing comma needed: ("Gather") is a String, so Set(("Gather")) would be a Set{Char} =# include("ext_cuda/runtests.jl") end