diff --git a/Project.toml b/Project.toml index 62e81db..cc5280c 100644 --- a/Project.toml +++ b/Project.toml @@ -9,16 +9,21 @@ StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" [extensions] +StridedcuBLASExt = "cuBLAS" StridedGPUArraysExt = "GPUArrays" +StridedAMDGPUExt = "AMDGPU" [compat] AMDGPU = "2" Aqua = "0.8" Adapt = "4" CUDACore = "6" +cuBLAS = "6" cuRAND = "6" GPUArrays = "11.4.1" JLArrays = "0.3.1" @@ -35,6 +40,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2" +cuBLAS = "182d3088-87b7-4494-8cad-fc6afaa545bc" cuRAND = "20fd9a0b-12d5-4c2f-a8af-7c34e9e60431" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -43,4 +49,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"] +test = ["Test", "Random", "Aqua", "AMDGPU", "CUDACore", "cuBLAS", "cuRAND", "GPUArrays", "JLArrays", "Metal", "Adapt"] diff --git a/ext/StridedAMDGPUExt.jl b/ext/StridedAMDGPUExt.jl new file mode 100644 index 0000000..627446e --- /dev/null +++ b/ext/StridedAMDGPUExt.jl @@ -0,0 +1,19 @@ +module StridedAMDGPUExt + +using Strided, StridedViews, AMDGPU, AMDGPU.rocBLAS, LinearAlgebra +import Strided: blas_mul! + +const ROCStridedView{T, N, A <: ROCArray{T}} = StridedViews.StridedView{T, N, A} + +function Strided.blas_mul!(C::ROCStridedView{T, 2}, A::ROCStridedView{T, 2}, B::ROCStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + A2, CA = Strided.getblasmatrix(A) + B2, CB = Strided.getblasmatrix(B) + C2, CC = Strided.getblasmatrix(C) + A2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(A2), size(A2)) + B2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(B2), size(B2)) + C2a = Base.unsafe_wrap(ROCMatrix{T}, pointer(C2), size(C2)) + AMDGPU.rocBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a) + return C +end + +end diff --git a/ext/StridedCUDACoreExt.jl b/ext/StridedCUDACoreExt.jl deleted file mode 100644 index a92bd97..0000000 --- a/ext/StridedCUDACoreExt.jl +++ /dev/null @@ -1,16 +0,0 @@ -module StridedCUDACoreExt - -using Strided, StridedViews, CUDACore -using CUDACore: Adapt, KernelAdaptor -using CUDACore: GPUArrays - -const ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)} - -function Base.copy!(dst::StridedView{TD, ND, TAD, FD}, src::StridedView{TS, NS, TAS, FS}) where {TD <: Number, ND, TAD <: CuArray{TD}, FD <: ALL_FS, TS <: Number, NS, TAS <: CuArray{TS}, FS <: ALL_FS} - bc_style = Base.Broadcast.BroadcastStyle(TAS) - bc = Base.Broadcast.Broadcasted(bc_style, identity, (src,), axes(dst)) - GPUArrays._copyto!(dst, bc) - return dst -end - -end diff --git a/ext/StridedGPUArraysExt.jl b/ext/StridedGPUArraysExt.jl index 6bfb21a..04728e8 100644 --- a/ext/StridedGPUArraysExt.jl +++ b/ext/StridedGPUArraysExt.jl @@ -5,6 +5,8 @@ using GPUArrays: Adapt, KernelAbstractions using GPUArrays.KernelAbstractions: @kernel, @index using StridedViews: ParentIndex +import Strided: isblasmatrix + ALL_FS = Union{typeof(adjoint), typeof(conj), typeof(identity), typeof(transpose)} # StridedView backed by any GPU array type, with element type linked to the parent. @@ -129,4 +131,17 @@ function Strided._mapreduce_block!( return nothing end +function Strided.isblasmatrix(A::GPUStridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat} + if A.op == identity + # unsafe wrap approach doesn't work if second condition not met + return stride(A, 1) == 1 && size(A, 1) == size(parent(A), 1) + elseif A.op == conj + # this is converted to adjoint + # unsafe wrap approach doesn't work if second condition not met + return stride(A, 2) == 1 && size(A, 2) == size(parent(A), 2) + else # should never happen + return false + end +end + end diff --git a/ext/StridedcuBLASExt.jl b/ext/StridedcuBLASExt.jl new file mode 100644 index 0000000..3c41a2c --- /dev/null +++ b/ext/StridedcuBLASExt.jl @@ -0,0 +1,19 @@ +module StridedcuBLASExt + +using Strided, StridedViews, cuBLAS, cuBLAS.CUDACore, LinearAlgebra +import Strided: blas_mul! + +const CuStridedView{T, N, A <: CuArray{T}} = StridedViews.StridedView{T, N, A} + +function Strided.blas_mul!(C::CuStridedView{T, 2}, A::CuStridedView{T, 2}, B::CuStridedView{T, 2}, α::Number, β::Number) where {T <: LinearAlgebra.BlasFloat} + A2, CA = Strided.getblasmatrix(A) + B2, CB = Strided.getblasmatrix(B) + C2, CC = Strided.getblasmatrix(C) + A2a = Base.unsafe_wrap(CuMatrix{T}, pointer(A2), size(A2)) + B2a = Base.unsafe_wrap(CuMatrix{T}, pointer(B2), size(B2)) + C2a = Base.unsafe_wrap(CuMatrix{T}, pointer(C2), size(C2)) + cuBLAS.gemm!(CA, CB, convert(T, α), A2a, B2a, convert(T, β), C2a) + return C +end + +end diff --git a/src/linalg.jl b/src/linalg.jl index febc573..cb56c51 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -100,13 +100,21 @@ function _mul!( α::Number, β::Number ) where {T <: LinearAlgebra.BlasFloat} if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B) - nthreads = use_threaded_mul() ? get_num_threads() : 1 - _threaded_blas_mul!(C, A, B, α, β, nthreads) + return blas_mul!(C, A, B, α, β) else return __mul!(C, A, B, α, β) end end +# for CPU based arrays, this is valid +function blas_mul!( + C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, + α::Number, β::Number + ) where {T <: LinearAlgebra.BlasFloat} + nthreads = use_threaded_mul() ? get_num_threads() : 1 + return _threaded_blas_mul!(C, A, B, α, β, nthreads) +end + function _threaded_blas_mul!( C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2}, α::Number, β::Number, diff --git a/test/cuda.jl b/test/cuda.jl deleted file mode 100644 index bc24b26..0000000 --- a/test/cuda.jl +++ /dev/null @@ -1,30 +0,0 @@ -using Test -using Strided -using CUDA: CuMatrix -using CUDA.Adapt: adapt - -for T in (Float32, Float64, Complex{Float32}, Complex{Float64}) - @testset "Copy with CuStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) - for m1 in (0, 16, 32), m2 in (0, 16, 32) - A1 = cuRAND.randn(T, (m1, m2)) - A2 = similar(A1) - zA1 = CuMatrix(f1(zeros(T, (m1, m2)))) - zA2 = CuMatrix(f2(zeros(T, (m1, m2)))) - A1c = copy(A1) - A2c = copy(A2) - B1 = f1(StridedView(A1c)) - B2 = f2(StridedView(A2c)) - axes(f1(A1)) == axes(f2(A2)) || continue - @test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDACore.Adapt.adapt(Vector{T}, copy!(B2, B1)) - @test copy!(zA1, f1(A1)) == copy!(zA2, B1) - A3 = CuArray(randn(T, (m1, m2))) - A3c = copy(A3) - B3 = f1(StridedView(A3c)) - @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted - @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted - @test CUDACore.Adapt.adapt(Vector{T}, f1(A1)) == CUDACore.Adapt.adapt(Vector{T}, B1) - x = rand(T) - @test f1(StridedView(CUDACore.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDACore.Adapt.adapt(Vector{T}, fill!(B1, x)) - end - end -end diff --git a/test/gpu.jl b/test/gpu.jl index 4959789..596ca00 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -14,10 +14,21 @@ end # types to test for ATs = [] !is_buildkite && push!(ATs, JLArray) -CUDACore.functional() && push!(ATs, CuArray) +CUDACore.functional() && cuBLAS.functional() && push!(ATs, CuArray) AMDGPU.functional() && push!(ATs, ROCArray) Metal.functional() && push!(ATs, MtlArray) +@testset "isblasmatrix ($AT)" for AT in ATs + for T in (Float32, ComplexF32) + A1 = StridedView(AT(randn(T, 20, 20))) + @test Strided.isblasmatrix(A1) + A2 = view(A1, 1:4:20, 1:5:20) + @test !Strided.isblasmatrix(A2) + A3 = view(conj!(A1), 1:4:20, 1:20) # stride(A3, 2) is not 1 + @test !Strided.isblasmatrix(A3) + end +end + @testset "in-place matrix operations ($AT)" for AT in ATs for T in (Float32, ComplexF32) A1 = StridedView(randn(T, 20, 20)) @@ -38,6 +49,38 @@ Metal.functional() && push!(ATs, MtlArray) end end +@testset "mul! ($AT{$T})" for AT in ATs, T in (Float32, ComplexF32) + N = 2 + α = rand(T) + β = rand(T) + dims = ntuple(Returns(div(64, N)), N) + A1 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A2 = permutedims(StridedView(rand(T, dims)), randperm(N)) + A3 = permutedims(StridedView(rand(T, dims)), randperm(N)) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, A1, A2, A3) + # test BLAS for all op combinations + @testset for sz in ((32, 64), (64, 64), (64, 32)) + vA1 = view(StridedView(rand(T, sz)), 1:32, 1:32) + vA2 = view(StridedView(rand(T, sz)), 1:32, 1:32) + vA3 = view(StridedView(rand(T, sz)), 1:32, 1:32) + @testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) + end + end + # non-BLAS fallback path + vA1 = view(StridedView(rand(T, (32, 32))), 1:32, 1:32) + vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) + vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + @testset for f1 in (identity, conj, adjoint, transpose), f2 in (identity, conj, adjoint, transpose) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, f1(vA2), f2(vA3)) + end + # non-BLAS fallback path + vA1 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + vA2 = view(StridedView(rand(T, (32, 64))), 1:32, 1:2:64) + vA3 = view(StridedView(rand(T, (64, 32))), 1:2:64, 1:32) + @test compare((C, A, B) -> mul!(C, A, B, α, β), AT, vA1, vA2, vA3) +end + @testset "map, scale!, axpy!, axpby! ($AT)" for AT in ATs for T in (Float32, ComplexF32) for N in 2:6 @@ -69,6 +112,22 @@ end end end +@testset "copy ($AT)" for AT in ATs + N = 2 + for m1 in (0, 16, 32), m2 in (0, 16, 32), T in (Float32, ComplexF32) + dims = (m1, m2) + A1 = StridedView(rand(T, dims)) + A2 = StridedView(rand(T, dims)) + A3 = StridedView(rand(T, dims)) + for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) + axes(f1(A1)) == axes(f2(A2)) || continue + B1 = f1(copy(A1)) + B2 = f2(copy(A2)) + @test compare((x, y) -> copy!(y, x), AT, B1, B2) + end + end +end + @testset "broadcasting ($AT)" for AT in ATs for T in (Float32, ComplexF32) A0 = StridedView(rand(T, ())) diff --git a/test/jlarrays.jl b/test/jlarrays.jl deleted file mode 100644 index fa163f6..0000000 --- a/test/jlarrays.jl +++ /dev/null @@ -1,25 +0,0 @@ -@testset for T in (Float32, Float64, Complex{Float32}, Complex{Float64}) - @testset "Copy with JLArrayStridedView: $T, $f1, $f2" for f2 in (identity, conj, adjoint, transpose), f1 in (identity, conj, transpose, adjoint) - for m1 in (0, 16, 32), m2 in (0, 16, 32) - A1 = JLArray(randn(T, (m1, m2))) - A2 = similar(A1) - zA1 = JLArray(f1(zeros(T, (m1, m2)))) - zA2 = JLArray(f2(zeros(T, (m1, m2)))) - A1c = copy(A1) - A2c = copy(A2) - B1 = f1(StridedView(A1c)) - B2 = f2(StridedView(A2c)) - axes(f1(A1)) == axes(f2(A2)) || continue - @test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1)) - @test copy!(zA1, f1(A1)) == copy!(zA2, B1) - A3 = JLArray(randn(T, (m1, m2))) - A3c = copy(A3) - B3 = f1(StridedView(A3c)) - @. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted - @. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted - @test JLArrays.Adapt.adapt(Vector{T}, f1(A1)) == JLArrays.Adapt.adapt(Vector{T}, B1) - x = rand(T) - @test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x)) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index e5d0da0..a091a4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using Aqua using Adapt, GPUArrays using JLArrays using AMDGPU -using CUDACore, cuRAND +using CUDACore, cuRAND, cuBLAS using Metal Random.seed!(1234)