From 472122ece30cb7388b0a35ab82daa75ba9ee4f14 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:13:33 +0100 Subject: [PATCH 01/11] Expanded dispatch of scatter! to include AbstractCuSparseArray --- ext/NNlibCUDAExt/scatter.jl | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 874207c77..15b7c1a08 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,4 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean +import CUDA: CUDA.CUSPARSE.AbstractcuSparseArray function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -23,7 +24,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) + j, k = divrem(index - 1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] = op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) end @@ -31,11 +32,11 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size end function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - max_idx, max_dims_idx, dims_size) where OP + max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index-1, max_dims_idx) + j, k = divrem(index - 1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) 
CUDA.@atomic dst[li] = op(dst[li], src[index]) @@ -43,7 +44,10 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn return nothing end -function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) where OP + +function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{SpAnyCuArray,AbstractCuSparseArray}, + idx::Union{SpAnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -55,7 +59,7 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArra op, dst, src, idx, max_idx, max_dims_idx, dims_size end - kernel = @cuda launch=false scatter_kernel!(args...) + kernel = @cuda launch = false scatter_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) @@ -63,7 +67,8 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArra return dst end -function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray) +function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, + src::Union{AnyCuArray,AbstractCuArray}, idx::Union{AnyCuArray,AbstractCuArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -74,7 +79,7 @@ end ## Gradients function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -94,7 +99,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + 
(blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -114,7 +119,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -137,7 +142,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -160,8 +165,8 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::AnyCuArray{Tsrc,Nsrc}, - idx::AnyCuArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx} + src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, + idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} dims = Nsrc - Nidx Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) rev_idx = NNlib.reverse_indices(idx) @@ -177,7 +182,7 @@ function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc end - kernel = @cuda launch=false ∇scatter_src_kernel!(args...) + kernel = @cuda launch = false ∇scatter_src_kernel!(args...) 
config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) From cfccdb3037ba46bf29dd28660a6764bd0341721d Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:28:17 +0100 Subject: [PATCH 02/11] Fix incorrect import --- ext/NNlibCUDAExt/scatter.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 15b7c1a08..29e135a8d 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,5 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean -import CUDA: CUDA.CUSPARSE.AbstractcuSparseArray +import CUDA.CUSPARSE: AbstractCuSparseArray function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x From 1551922af093e11ef9f343ad9b2ee357216dcc1f Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:40:21 +0100 Subject: [PATCH 03/11] restore formatting --- ext/NNlibCUDAExt/scatter.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 29e135a8d..485fca39b 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -24,7 +24,7 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index - 1, max_dims_idx) + j, k = divrem(index-1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] CUDA.@atomic dst[Tuple(dims_i)..., idx[j+1]...] 
= op(dst[Tuple(dims_i)..., idx[j+1]...], src[index]) end @@ -32,11 +32,11 @@ function scatter_kernel!(op::OP, dst, src, idx, max_idx, max_dims_idx, dims_size end function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - max_idx, max_dims_idx, dims_size) where OP + max_idx, max_dims_idx, dims_size) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx - j, k = divrem(index - 1, max_dims_idx) + j, k = divrem(index-1, max_dims_idx) dims_i = CartesianIndices(dims_size)[k+1] li = Base._to_linear_index(dst, Tuple(dims_i)..., Tuple(idx[j+1])...) CUDA.@atomic dst[li] = op(dst[li], src[index]) @@ -59,7 +59,7 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, op, dst, src, idx, max_idx, max_dims_idx, dims_size end - kernel = @cuda launch = false scatter_kernel!(args...) + kernel = @cuda launch=false scatter_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) @@ -99,7 +99,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -142,7 +142,7 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx, end function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, - rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} + rev_idx, pre_cart_idx, max_dims_idx, max_idx, T::Type{TT}) where {OP,TT} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @inbounds if index <= max_idx @@ -182,7 +182,7 @@ function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, args = op, Δsrc, src, idx, rev_idx, pre_cart_idx, max_dims_idx, max_idx, Tsrc end - kernel = 
@cuda launch = false ∇scatter_src_kernel!(args...) + kernel = @cuda launch=false ∇scatter_src_kernel!(args...) config = launch_configuration(kernel.fun; max_threads=256) threads = min(max_idx, config.threads) blocks = cld(max_idx, threads) From c749dfcc11449eb604504b462e0bab676bf06d71 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Fri, 14 Nov 2025 15:58:59 +0100 Subject: [PATCH 04/11] Typos --- ext/NNlibCUDAExt/scatter.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 485fca39b..9b323d504 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -46,8 +46,8 @@ end function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{SpAnyCuArray,AbstractCuSparseArray}, - idx::Union{SpAnyCuArray,AbstractCuSparseArray}) where OP + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -68,7 +68,8 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, end function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuArray}, idx::Union{AnyCuArray,AbstractCuArray}) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) From 94f18bbd951563f21fe4828c5877f9e6f1006ffd Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Tue, 25 Nov 2025 22:37:47 +0100 Subject: [PATCH 05/11] Adding tests for scatter on CUSPARSE arrays --- test/ext_cuda/runtests.jl | 1 + test/ext_cuda/scatter.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index c0b140036..1db0b6ee9 100644 --- 
a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -4,6 +4,7 @@ using Zygote using ForwardDiff: Dual using Statistics: mean using CUDA, cuDNN +import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO using NNlib: batchnorm, ∇batchnorm CUDA.allowscalar(false) diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index a4977f285..10b660e40 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -21,7 +21,7 @@ idxs = [ (3,) (5,) (5,) (3,)])), # CartesianIndex index ] -types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] @testset "scatter" begin @@ -70,7 +70,7 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] end - for T = [CuArray{Float32}, CuArray{Float64}] + for T = [CuArray{Float32}, CuArray{Float64}, Sparse, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] @testset "$(T)" begin @testset "*" begin for idx = idxs, dims = [0, 1] From 098aca0d7bfc262c3c6cad084e2efae9a2fdf12a Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Tue, 27 Jan 2026 15:31:51 +0100 Subject: [PATCH 06/11] Fixing erroneously made tests for specialized sparse scatter kernels. 
--- test/ext_cuda/scatter.jl | 108 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 3 deletions(-) diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index 10b660e40..421099aab 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -1,13 +1,13 @@ dsts = Dict( 0 => cu([3, 4, 5, 6, 7]), 1 => cu([3 3 4 4 5; - 5 5 6 6 7]), + 5 5 6 6 7]), ) srcs = Dict( (0, true) => cu(ones(Int, 3, 4)), (0, false) => cu(ones(Int, 3) * collect(1:4)'), (1, true) => cu(ones(Int, 2, 3, 4)), - (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1, 3, 4)), ) idxs = [ cu([1 2 3 4; @@ -70,7 +70,109 @@ types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuS end - for T = [CuArray{Float32}, CuArray{Float64}, Sparse, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] + # Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. + dsts_sp = Dict( + 0 => cu(sparse([3, 4, 5, 6, 7])), + 1 => cu(sparse([3 3 4 4 5; + 5 5 6 6 7])), + ) + srcs_sp = Dict( + (0, true) => cu(sparse(ones(Int, 3, 4))), + (0, false) => cu(sparse(ones(Int, 3) * collect(1:4)')), + # No sparse equivalent for 3D arrays + ) + types_sp = [ + CuSparseMatrixCSC{Int32}, CuSparseMatrixCSC{Int64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, + CuSparseMatrixCSR{Int32}, CuSparseMatrixCSR{Int64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, + CuSparseMatrixCOO{Int32}, CuSparseMatrixCOO{Int64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64} + ] + + @testset "scatter sparse-specialized" begin + for T = types_sp + @testset "$(T)" begin + @testset "+" begin + # Dims is implicitly 0. 
No sparse equivant for multidimensional src/dst + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "-" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "max" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "min" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + end + end + + # Sparse-specialized for operations not tested on eltype <: Integer + for T = [CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64}] + @testset "$(T)" begin + # Dims is implicitly 0. 
No sparse equivant for multidimensional src/dst + @testset "*" begin + for idx = idxs + mutated = true + gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "/" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + + @testset "mean" begin + for idx = idxs, dims = [0, 1] + mutated = true + gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + + mutated = false + gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(0, mutated)]), checkgrad=true) + end + end + end + end + end + + + + for T = [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin for idx = idxs, dims = [0, 1] From ff2e79cdf75ab5d11ae850e4f921b4e110aa4437 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Tue, 14 Apr 2026 20:05:38 +0200 Subject: [PATCH 07/11] Adding sparse GPU tests for scatter --- test/ext_cuda/scatter.jl | 131 +++++++++++++----------------------- test/ext_cuda/test_utils.jl | 23 +++++++ 2 files changed, 69 insertions(+), 85 deletions(-) diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index 421099aab..d88cbfdb6 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -21,8 +21,7 @@ idxs = [ (3,) (5,) (5,) (3,)])), # CartesianIndex index ] -types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCOO{Float32}] - +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] @testset "scatter" begin for T = types @@ -70,105 +69,67 @@ types 
= [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}, CuS end - # Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. - dsts_sp = Dict( - 0 => cu(sparse([3, 4, 5, 6, 7])), - 1 => cu(sparse([3 3 4 4 5; - 5 5 6 6 7])), - ) - srcs_sp = Dict( - (0, true) => cu(sparse(ones(Int, 3, 4))), - (0, false) => cu(sparse(ones(Int, 3) * collect(1:4)')), - # No sparse equivalent for 3D arrays - ) - types_sp = [ - CuSparseMatrixCSC{Int32}, CuSparseMatrixCSC{Int64}, CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, - CuSparseMatrixCSR{Int32}, CuSparseMatrixCSR{Int64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, - CuSparseMatrixCOO{Int32}, CuSparseMatrixCOO{Int64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64} - ] - - @testset "scatter sparse-specialized" begin - for T = types_sp - @testset "$(T)" begin - @testset "+" begin - # Dims is implicitly 0. No sparse equivant for multidimensional src/dst - for idx = idxs - mutated = true - gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end - end - - @testset "-" begin - for idx = idxs - mutated = true - gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end - end +# Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. 
+srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt) +num_types = [Float32, Float64] - @testset "max" begin - for idx = idxs - mutated = true - gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) +@testset "scatter sparse-specialized" begin + for T in num_types + @testset "+" begin + # Dims is implicitly 0. No sparse equivant for multidimensional src + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end - mutated = false - gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end - end + @testset "-" begin + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(-, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end - @testset "min" begin - for idx = idxs - mutated = true - gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) + @testset "max" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(max, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end - mutated = false - gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end - end + @testset "min" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(min, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end end # Sparse-specialized for operations not tested on eltype <: Integer - for T = [CuSparseMatrixCSC{Float32}, CuSparseMatrixCSC{Float64}, CuSparseMatrixCSR{Float32}, CuSparseMatrixCSR{Float64}, CuSparseMatrixCOO{Float32}, CuSparseMatrixCOO{Float64}] - @testset "$(T)" begin - # Dims is implicitly 0. 
No sparse equivant for multidimensional src/dst - @testset "*" begin - for idx = idxs - mutated = true - gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end + @testset "$(T)" begin + # Dims is implicitly 0. No sparse equivant for multidimensional src/dst + @testset "*" begin + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(*, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end + end - @testset "/" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end + @testset "/" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(/, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end + end - @testset "mean" begin - for idx = idxs, dims = [0, 1] - mutated = true - gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[0])), T(srcs[(0, mutated)]), checkgrad=true) - - mutated = false - gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(0, mutated)]), checkgrad=true) - end + @testset "mean" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(mean, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end end end end +end diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl index 18f307e0c..7ea139db6 100644 --- a/test/ext_cuda/test_utils.jl +++ b/test/ext_cuda/test_utils.jl @@ -20,3 +20,26 @@ function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, end end end + +function gputest(f, xs::Vararg{Union{GPUArrays.AbstractGPUArray, 
GPUArrays.AbstractGPUSparseArray}}; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) + gpu_in = xs + cpu_in = map(CuArray, xs) + + cpu_out = f(cpu_in...; kws...) + gpu_out = f(gpu_in...; kws...) |> collect + @test collect(cpu_out) ≈ gpu_out rtol=rtol atol=atol broken=broken + + if checkgrad + # use mean instead of sum to prevent error accumulation (for larger + # tensors) which causes error to go above atol + cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...) + gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...) + for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad) + if cpu_g === nothing + @test gpu_g === nothing + else + @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad + end + end + end +end \ No newline at end of file From b62f857e01feb7f0ed2af312c93d617e8418701d Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Tue, 14 Apr 2026 20:34:58 +0200 Subject: [PATCH 08/11] Fixing type instability preventing correct calculation of gradients --- ext/NNlibCUDAExt/scatter.jl | 22 +++++++++++----------- test/ext_cuda/runtests.jl | 1 + 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index f9e3b6f14..1e05e26b1 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,5 +1,4 @@ # supported op: +, -, *, /, max, min, &, |, mean -import CUDA.CUSPARSE: AbstractCuSparseArray ## TODO support sparse dst/src/idx ## See issue https://github.com/FluxML/NNlib.jl/issues/647 @@ -55,10 +54,10 @@ function scatter_kernel!(op::OP, dst, src, idx::CUDA.CuDeviceArray{<:CartesianIn return nothing end -function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuSparseArray}, - idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP - isempty(idx) && return dst + +function NNlib.scatter!(op::OP, dst::AnyCuArray, + src::Union{AnyCuArray,AbstractGPUSparseArray}, + 
idx::Union{AnyCuArray,AbstractGPUSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -78,9 +77,9 @@ function NNlib.scatter!(op::OP, dst::Union{AnyCuArray,AbstractCuSparseArray}, return dst end -function NNlib.scatter!(op::typeof(mean), dst::Union{AnyCuArray,AbstractCuSparseArray}, - src::Union{AnyCuArray,AbstractCuSparseArray}, - idx::Union{AnyCuArray,AbstractCuSparseArray}) +function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, + src::Union{AnyCuArray,AbstractGPUSparseArray}, + idx::Union{AnyCuArray,AbstractGPUSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -177,10 +176,11 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::Union{AnyCuArray{Tsrc,Nsrc},AbstractCuSparseArray}, - idx::Union{AnyCuArray{Tidx,Nidx},AbstractCuSparseArray}) where {Tsrc,Tidx,Nsrc,Nidx} + src::Union{AnyCuArray,AbstractGPUSparseArray}, + idx::Union{AnyCuArray,AbstractGPUSparseArray}) dims = ndims(src) - ndims(idx) - Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), src) + gathered_src = CuArray(src) # Convert to dense to avoid type-unstable broadcast if sparse + Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), gathered_src) rev_idx = NNlib.reverse_indices(idx) rev_idx = CuArray(map(CUDA.cudaconvert, rev_idx)) diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 1db0b6ee9..279b53e86 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -5,6 +5,7 @@ using ForwardDiff: Dual using Statistics: mean using CUDA, cuDNN import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO +import GPUArrays: AbstractGPUSparseArray using NNlib: batchnorm, ∇batchnorm CUDA.allowscalar(false) From 09bc7e9359668e74344fe8785c93a37d888d8720 Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros 
Date: Wed, 15 Apr 2026 11:06:12 +0200 Subject: [PATCH 09/11] Fixing missing dependency & actually running the scatter tests --- Project.toml | 1 + ext/NNlibCUDAExt/NNlibCUDAExt.jl | 1 + test/ext_cuda/runtests.jl | 1 + test/runtests.jl | 2 +- 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f0991e85..69353a197 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.9.33" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/NNlibCUDAExt/NNlibCUDAExt.jl b/ext/NNlibCUDAExt/NNlibCUDAExt.jl index 7939b8086..e4a676018 100644 --- a/ext/NNlibCUDAExt/NNlibCUDAExt.jl +++ b/ext/NNlibCUDAExt/NNlibCUDAExt.jl @@ -3,6 +3,7 @@ module NNlibCUDAExt using NNlib using CUDA using Random, Statistics +using GPUArrays: AbstractGPUSparseArray include("sampling.jl") include("activations.jl") diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 279b53e86..2c070297a 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -5,6 +5,7 @@ using ForwardDiff: Dual using Statistics: mean using CUDA, cuDNN import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO +using GPUArrays import GPUArrays: AbstractGPUSparseArray using NNlib: batchnorm, ∇batchnorm CUDA.allowscalar(false) diff --git a/test/runtests.jl b/test/runtests.jl index baa19a1c4..0c5f22c37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,7 +155,7 @@ end using CUDA if CUDA.functional() @testset "CUDA" begin - nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather"))) + nnlib_testsuite(CUDABackend; skip_tests=Set(("Gather",))) include("ext_cuda/runtests.jl") end From
ffa629a9d9beca1503549d0de4d0e9342461cb0a Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Wed, 15 Apr 2026 12:52:45 +0200 Subject: [PATCH 10/11] Dispatch scatter! only on specific CUDA sparse arrays --- ext/NNlibCUDAExt/NNlibCUDAExt.jl | 15 +++- ext/NNlibCUDAExt/scatter.jl | 23 ++--- test/ext_cuda/runtests.jl | 18 +++- test/ext_cuda/scatter.jl | 142 ++++++++++++++++--------------- test/ext_cuda/test_utils.jl | 2 +- 5 files changed, 108 insertions(+), 92 deletions(-) diff --git a/ext/NNlibCUDAExt/NNlibCUDAExt.jl b/ext/NNlibCUDAExt/NNlibCUDAExt.jl index e4a676018..338d56055 100644 --- a/ext/NNlibCUDAExt/NNlibCUDAExt.jl +++ b/ext/NNlibCUDAExt/NNlibCUDAExt.jl @@ -2,8 +2,21 @@ module NNlibCUDAExt using NNlib using CUDA +import CUDA.CUSPARSE: + AbstractCuSparseVector, + CuSparseMatrixCSC, + CuSparseMatrixCSR, + CuSparseMatrixBSR, + CuSparseMatrixCOO using Random, Statistics -using GPUArrays: AbstractGPUSparseArray + +const AbstractCuSparseArray{Tv,Ti} = Union{ + AbstractCuSparseVector{Tv,Ti}, + CuSparseMatrixCSC{Tv,Ti}, + CuSparseMatrixCSR{Tv,Ti}, + CuSparseMatrixBSR{Tv,Ti}, + CuSparseMatrixCOO{Tv,Ti}, +} include("sampling.jl") include("activations.jl") diff --git a/ext/NNlibCUDAExt/scatter.jl b/ext/NNlibCUDAExt/scatter.jl index 1e05e26b1..0b6dca0ea 100644 --- a/ext/NNlibCUDAExt/scatter.jl +++ b/ext/NNlibCUDAExt/scatter.jl @@ -1,16 +1,5 @@ # supported op: +, -, *, /, max, min, &, |, mean -## TODO support sparse dst/src/idx -## See issue https://github.com/FluxML/NNlib.jl/issues/647 -# import CUDA.CUSPARSE: AbstractCuSparseMatrix, CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, AnyCuSparseVector -# const AnyCuSparseMatrix{Tv,Ti} = Union{ -# AbstractCuSparseMatrix{Tv,Ti}, -# CUDA.CuSparseMatrixCSC{Tv,Ti}, # these types do not inherit from AbstractCuSparseMatrix -# CUDA.CuSparseMatrixCSR{Tv,Ti}, # but from GPUArrays.AbstractGPUSparseMatrixXXX -# CUDA.CuSparseMatrixCOO{Tv,Ti}, -# } -# const AnyCuSparseArray{Tv,Ti} = 
Union{AnyCuSparseVector{Tv,Ti},AnyCuSparseMatrix{Tv,Ti}} - function scatter_kernel!(op::OP, dst, src, idx) where OP index = threadIdx().x + (blockIdx().x - 1) * blockDim().x @@ -56,8 +45,8 @@ end function NNlib.scatter!(op::OP, dst::AnyCuArray, - src::Union{AnyCuArray,AbstractGPUSparseArray}, - idx::Union{AnyCuArray,AbstractGPUSparseArray}) where OP + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) where OP dims = NNlib.scatter_dims(dst, src, idx) args = if dims == 0 max_idx = length(idx) @@ -78,8 +67,8 @@ function NNlib.scatter!(op::OP, dst::AnyCuArray, end function NNlib.scatter!(op::typeof(mean), dst::AnyCuArray, - src::Union{AnyCuArray,AbstractGPUSparseArray}, - idx::Union{AnyCuArray,AbstractGPUSparseArray}) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) dst_ = NNlib.scatter!(+, zero(dst), src, idx) dst .+= NNlib.safe_div.(dst_, Ns) @@ -176,8 +165,8 @@ function ∇scatter_src_kernel!(op::OP, Δsrc, src, idx::CUDA.CuDeviceArray{<:Ca end function NNlib.∇scatter_src(op::Union{typeof(*),typeof(/)}, Δ, dst, - src::Union{AnyCuArray,AbstractGPUSparseArray}, - idx::Union{AnyCuArray,AbstractGPUSparseArray}) + src::Union{AnyCuArray,AbstractCuSparseArray}, + idx::Union{AnyCuArray,AbstractCuSparseArray}) dims = ndims(src) - ndims(idx) gathered_src = CuArray(src) # Convert to dense to avoid type-unstable broadcast if sparse Δsrc = NNlib.modify_src(op, NNlib.gather(Δ, idx), gathered_src) diff --git a/test/ext_cuda/runtests.jl b/test/ext_cuda/runtests.jl index 2c070297a..e07c41857 100644 --- a/test/ext_cuda/runtests.jl +++ b/test/ext_cuda/runtests.jl @@ -4,12 +4,24 @@ using Zygote using ForwardDiff: Dual using Statistics: mean using CUDA, cuDNN -import CUDA.CUSPARSE: CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO -using GPUArrays -import GPUArrays: AbstractGPUSparseArray +using SparseArrays +import CUDA.CUSPARSE: + 
AbstractCuSparseVector, + CuSparseMatrixCSC, + CuSparseMatrixCSR, + CuSparseMatrixBSR, + CuSparseMatrixCOO using NNlib: batchnorm, ∇batchnorm CUDA.allowscalar(false) +const AbstractCuSparseArray{Tv,Ti} = Union{ + AbstractCuSparseVector{Tv,Ti}, + CuSparseMatrixCSC{Tv,Ti}, + CuSparseMatrixCSR{Tv,Ti}, + CuSparseMatrixBSR{Tv,Ti}, + CuSparseMatrixCOO{Tv,Ti}, +} + include("test_utils.jl") include("activations.jl") include("dropout.jl") diff --git a/test/ext_cuda/scatter.jl b/test/ext_cuda/scatter.jl index d88cbfdb6..f7070162b 100644 --- a/test/ext_cuda/scatter.jl +++ b/test/ext_cuda/scatter.jl @@ -15,153 +15,155 @@ idxs = [ 3 5 5 3]), # integer index cu([(1,) (2,) (3,) (4,); (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)]), # tuple index - cu(CartesianIndex.([(1,) (2,) (3,) (4,); - (4,) (2,) (1,) (3,); - (3,) (5,) (5,) (3,)])), # CartesianIndex index + (3,) (5,) (5,) (3,) + ]), # tuple index + cu(CartesianIndex.([ + (1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,) + ])), # CartesianIndex index ] types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] @testset "scatter" begin - for T = types + for T in types @testset "$(T)" begin @testset "+" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(+, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(+, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "-" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(-, dst, src, idx), 
T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(-, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "max" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(max, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(max, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "min" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(min, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(min, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end end end -# Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. -srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt) -num_types = [Float32, Float64] - -@testset "scatter sparse-specialized" begin - for T in num_types - @testset "+" begin - # Dims is implicitly 0. 
No sparse equivant for multidimensional src - for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] - # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) - end - end - - @testset "-" begin - for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] - # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(-, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) - end - end - - @testset "max" begin - for idx in idxs, fmt in [:csc, :csr] - # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(max, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) - end - end + # Specialized sparse scatter kernels. Duplicated as test cases above do not cover sparse arrays. + srcs_sp(T, fmt) = sparse(cu(ones(T, 3, 4)), fmt = fmt) + num_types = [Float32, Float64] - @testset "min" begin - for idx in idxs, fmt in [:csc, :csr] - # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(min, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + @testset "scatter sparse-specialized" begin + for T in num_types + @testset "+" begin + # Dims is implicitly 0. No sparse equivant for multidimensional src + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(+, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end end - end - # Sparse-specialized for operations not tested on eltype <: Integer - @testset "$(T)" begin - # Dims is implicitly 0. 
No sparse equivant for multidimensional src/dst - @testset "*" begin + @testset "-" begin for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(*, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + gputest((dst, src, idx) -> NNlib.scatter!(-, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end end - @testset "/" begin + @testset "max" begin for idx in idxs, fmt in [:csc, :csr] # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(/, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + gputest((dst, src, idx) -> NNlib.scatter!(max, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) end end - @testset "mean" begin + @testset "min" begin for idx in idxs, fmt in [:csc, :csr] # mutated = true - gputest((dst, src, idx) -> NNlib.scatter!(mean, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + gputest((dst, src, idx) -> NNlib.scatter!(min, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + # Sparse-specialized for operations not tested on eltype <: Integer + @testset "$(T)" begin + # Dims is implicitly 0. 
No sparse equivant for multidimensional src/dst + @testset "*" begin + for idx in idxs, fmt in [:csc, :csr, :bsr, :coo] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(*, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "/" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(/, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end + end + + @testset "mean" begin + for idx in idxs, fmt in [:csc, :csr] + # mutated = true + gputest((dst, src, idx) -> NNlib.scatter!(mean, dst, src, idx), copy(dsts[0]), srcs_sp(T, fmt), idx) + end end end end end -end - - for T = [CuArray{Float32}, CuArray{Float64}] + for T in [CuArray{Float32}, CuArray{Float64}] @testset "$(T)" begin @testset "*" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(*, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(*, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "/" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(/, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(/, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end @testset "mean" begin - for idx = idxs, dims = [0, 1] + for idx in idxs, dims in [0, 1] mutated = true - gputest((dst, src) -> 
NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad=true) + gputest((dst, src) -> NNlib.scatter!(mean, dst, src, idx), T(copy(dsts[dims])), T(srcs[(dims, mutated)]), checkgrad = true) mutated = false - gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad=true) + gputest(src -> NNlib.scatter(mean, src, idx), T(srcs[(dims, mutated)]), checkgrad = true) end end end diff --git a/test/ext_cuda/test_utils.jl b/test/ext_cuda/test_utils.jl index 7ea139db6..1cceb563e 100644 --- a/test/ext_cuda/test_utils.jl +++ b/test/ext_cuda/test_utils.jl @@ -21,7 +21,7 @@ function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, end end -function gputest(f, xs::Vararg{Union{GPUArrays.AbstractGPUArray, GPUArrays.AbstractGPUSparseArray}}; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) +function gputest(f, xs::Vararg{Union{AnyCuArray,AbstractCuSparseArray}}; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...) gpu_in = xs cpu_in = map(CuArray, xs) From 40fcc561d6fcbce087c1f4b3ef3902833d8ae51c Mon Sep 17 00:00:00 2001 From: Alonso Martinez Cisneros Date: Tue, 5 May 2026 12:03:58 +0200 Subject: [PATCH 11/11] Fixing similar issues with gather! 
--- src/NNlib.jl | 2 ++ src/gather.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/NNlib.jl b/src/NNlib.jl index 5ea907783..41b41560d 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -7,6 +7,8 @@ using Base.Broadcast: broadcasted using Base.Threads using ChainRulesCore using GPUArraysCore +using GPUArrays +using GPUArrays: AbstractGPUSparseArray using KernelAbstractions using KernelAbstractions: @atomic using LinearAlgebra diff --git a/src/gather.jl b/src/gather.jl index 7997f8784..8d9e20a0e 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -109,7 +109,7 @@ function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray) return dst end -function gather!(dst::AnyGPUArray, src::AnyGPUArray, idx::AnyGPUArray) +function gather!(dst::AnyGPUArray, src::Union{AnyGPUArray, AbstractGPUSparseArray}, idx::AnyGPUArray) isempty(dst) && return dst n_dims = scatter_dims(src, dst, idx) dims = size(src)[1:n_dims]