diff --git a/lib/mkl/array.jl b/lib/mkl/array.jl index f97b807c..d7120811 100644 --- a/lib/mkl/array.jl +++ b/lib/mkl/array.jl @@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} nnz::Ti end +mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} + handle::matrix_handle_t + colPtr::oneVector{Ti} + rowVal::oneVector{Ti} + nzVal::oneVector{Tv} + dims::NTuple{2,Int} + nnz::Ti +end + mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} handle::matrix_handle_t rowInd::oneVector{Ti} @@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz SparseArrays.nonzeros(A::oneAbstractSparseMatrix) = A.nzVal for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC, + :oneSparseMatrixCSC => :SparseMatrixCSC, :oneSparseMatrixCOO => :SparseMatrixCSC] @eval Base.show(io::IOContext, x::$gpu) = show(io, $cpu(x)) diff --git a/lib/mkl/interfaces.jl b/lib/mkl/interfaces.jl index 725b120d..9e9100ca 100644 --- a/lib/mkl/interfaces.jl +++ b/lib/mkl/interfaces.jl @@ -7,12 +7,23 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A:: sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C) end +function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal + tA = tA in ('S', 's', 'H', 'h') ? 'T' : (tA == 'N' ? 'T' : 'N') + sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C) +end + function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C) end +function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal + tA = tA in ('S', 's', 'H', 'h') ? 'T' : (tA == 'N' ? 'T' : 'N') + tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB + sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C) +end + for SparseMatrixType in (:oneSparseMatrixCSR,) @eval begin function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat diff --git a/lib/mkl/wrappers_sparse.jl b/lib/mkl/wrappers_sparse.jl index f39cbd03..e4c0fa7b 100644 --- a/lib/mkl/wrappers_sparse.jl +++ b/lib/mkl/wrappers_sparse.jl @@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3 A_csc = SparseMatrixCSC(At |> transpose) return A_csc end + + function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty}) + handle_ptr = Ref{matrix_handle_t}() + onemklXsparse_init_matrix_handle(handle_ptr) + m, n = size(A) + colPtr = oneVector{$intty}(A.colptr) + rowVal = oneVector{$intty}(A.rowval) + nzVal = oneVector{$elty}(A.nzval) + nnzA = length(A.nzval) + queue = global_queue(context(nzVal), device()) + $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ + dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA) + finalizer(sparse_release_matrix_handle, dA) + return dA + end + + function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty}) + handle_ptr = Ref{matrix_handle_t}() + A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal)) + return A_csc + end end end @@ -100,6 +121,78 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO) end end +# Special handling for CSC matrices since they are stored as transposed CSR +for (fname, elty) in ((:onemklSsparse_gemv, :Float32), + (:onemklDsparse_gemv, :Float64), + (:onemklCsparse_gemv, :ComplexF32), + (:onemklZsparse_gemv, :ComplexF64)) + @eval begin + function sparse_gemv!(trans::Char, + alpha::Number, + A::oneSparseMatrixCSC{$elty}, + x::oneStridedVector{$elty}, + beta::Number, + y::oneStridedVector{$elty}) + + # CSC(A) is represented by storing CSR(A^T). Map operations accordingly: + # - trans = 'N': want A*x -> use op(S)='T' with S=A^T. + # - trans = 'T': want A^T*x -> use op(S)='N' with S=A^T. + # - trans = 'C': want A^H*x. + # * For real eltypes, A^H == A^T -> use op(S)='N'. + # * For complex eltypes, we cannot express A^H using a single op(S). + # Use identity: conj(y_new) = conj(alpha) * A * conj(x) + conj(beta) * conj(y) + # and compute with op(S)='T' (since S^T = A), conjugating x and y around the call. + + if trans == 'N' + queue = global_queue(context(x), device()) + $fname(sycl_queue(queue), 'T', alpha, A.handle, x, beta, y) + return y + elseif trans == 'T' + queue = global_queue(context(x), device()) + $fname(sycl_queue(queue), 'N', alpha, A.handle, x, beta, y) + return y + else + # trans == 'C' + if $elty <: Complex + # Compute A^H*x via identity: + # conj(y_new) = conj(alpha) * (A^T) * conj(x) + conj(beta) * conj(y) + # Since S=A^T and op='N' computes S*x = A^T*x, we can realize this with one call. + y .= conj.(y) + x_conj = similar(x) + x_conj .= conj.(x) + + queue = global_queue(context(x), device()) + $fname(sycl_queue(queue), 'N', conj(alpha), A.handle, x_conj, conj(beta), y) + + y .= conj.(y) + return y + else + # real eltype: A^H == A^T + queue = global_queue(context(x), device()) + $fname(sycl_queue(queue), 'N', alpha, A.handle, x, beta, y) + return y + end + end + end + end +end + +function sparse_optimize_gemv!(trans::Char, A::oneSparseMatrixCSC) + # For CSC matrices stored as transposed CSR, we need to optimize with the transposed operation + if trans == 'N' + actual_trans = 'T' + elseif trans == 'T' + actual_trans = 'N' + else # trans == 'C' + # complex 'C' case is implemented using op='N' on S=A^T with conjugation trick + actual_trans = 'N' + end + + queue = global_queue(context(A.nzVal), device(A.nzVal)) + onemklXsparse_optimize_gemv(sycl_queue(queue), actual_trans, A.handle) + return A +end + for (fname, elty) in ((:onemklSsparse_gemm, :Float32), (:onemklDsparse_gemm, :Float64), (:onemklCsparse_gemm, :ComplexF32), @@ -139,6 +232,115 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars return A end +# Special handling for CSC matrices since they are stored as transposed CSR (S = A^T) +for (fname, elty) in ((:onemklSsparse_gemm, :Float32), + (:onemklDsparse_gemm, :Float64), + (:onemklCsparse_gemm, :ComplexF32), + (:onemklZsparse_gemm, :ComplexF64)) + @eval begin + function sparse_gemm!(transa::Char, + transb::Char, + alpha::Number, + A::oneSparseMatrixCSC{$elty}, + B::oneStridedMatrix{$elty}, + beta::Number, + C::oneStridedMatrix{$elty}) + + # Map op(A) to op(S) where S = A^T stored as CSR in the handle + # transa: 'N' -> op(S)='T'; 'T' -> op(S)='N'; 'C' -> + # real: op(S)='N' (since A^H == A^T) + # complex: use conjugation identity on B and C with op(S)='N' + + mB, nB = size(B) + mC, nC = size(C) + (nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns.")) + (mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns.")) + nrhs = size(B, 2) + ldb = max(1,stride(B,2)) + ldc = max(1,stride(C,2)) + + queue = global_queue(context(C), device()) + + if transa == 'N' + # Want A * opB -> use S^T * opB + $fname(sycl_queue(queue), 'C', 'T', transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc) + return C + elseif transa == 'T' + # Want A^T * opB -> use S * opB + $fname(sycl_queue(queue), 'C', 'N', transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc) + return C + else + # transa == 'C' + if $elty <: Complex + # Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C) + # Prepare conj(C) in-place and conj(B) into a temporary if needed + C .= conj.(C) + + # Determine how to supply opB under conjugation + # - transb == 'N': pass transb='N' and use conj(B) + # - transb == 'T': pass transb='T' and use conj(B) + # - transb == 'C': since conj(B^H) = B^T, pass transb='T' and use B as-is + local transb_eff::Char + local Beff + if transb == 'N' + transb_eff = 'N' + Beff = similar(B) + Beff .= conj.(B) + elseif transb == 'T' + transb_eff = 'T' + Beff = similar(B) + Beff .= conj.(B) + else + # transb == 'C' + transb_eff = 'T' + Beff = B + end + + $fname(sycl_queue(queue), 'C', 'N', transb_eff, conj(alpha), A.handle, Beff, nrhs, ldb, conj(beta), C, ldc) + + # Undo conjugation to obtain C_new + C .= conj.(C) + return C + else + # real eltype: A^H == A^T -> use S * opB + $fname(sycl_queue(queue), 'C', 'N', transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc) + return C + end + end + end + end +end + +function sparse_optimize_gemm!(trans::Char, A::oneSparseMatrixCSC) + # Map requested op(A) to op(S) for S = A^T + actual_trans = if trans == 'N' + 'T' + elseif trans == 'T' + 'N' + else + # 'C' case: complex handled via conjugation with op(S)='N'; real reduces to 'T' which maps to 'N' + 'N' + end + queue = global_queue(context(A.nzVal), device(A.nzVal)) + onemklXsparse_optimize_gemm(sycl_queue(queue), actual_trans, A.handle) + return A +end + +function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSparseMatrixCSC) + # Map as in the basic optimize, and adjust transB for the complex 'C' case if needed at runtime. + # We don't know eltype here; choose conservative mapping for A like above. + actual_trans = if trans == 'N' + 'T' + elseif trans == 'T' + 'N' + else + 'N' + end + queue = global_queue(context(A.nzVal), device(A.nzVal)) + onemklXsparse_optimize_gemm_advanced(sycl_queue(queue), 'C', actual_trans, transB, A.handle, nrhs) + return A +end + for (fname, elty) in ((:onemklSsparse_symv, :Float32), (:onemklDsparse_symv, :Float64), (:onemklCsparse_symv, :ComplexF32), diff --git a/test/Project.toml b/test/Project.toml index c214ed96..90670d48 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36" diff --git a/test/onemkl.jl b/test/onemkl.jl index e5b6541c..ab0a620d 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -3,7 +3,7 @@ if Sys.iswindows() else using oneAPI -using oneAPI.oneMKL: band, bandex, oneSparseMatrixCSR, oneSparseMatrixCOO +using oneAPI.oneMKL: band, bandex, oneSparseMatrixCSR, oneSparseMatrixCOO, oneSparseMatrixCSC using SparseArrays using LinearAlgebra @@ -1088,6 +1088,17 @@ end end end + @testset "oneSparseMatrixCSC" begin + (T isa Complex) && continue + for S in (Int32, Int64) + A = sprand(T, 20, 10, 0.5) + A = SparseMatrixCSC{T, S}(A) + B = oneSparseMatrixCSC(A) + A2 = SparseMatrixCSC(B) + @test A == A2 + end + end + @testset "oneSparseMatrixCOO" begin for S in (Int32, Int64) A = sprand(T, 20, 10, 0.5) @@ -1099,7 +1110,7 @@ end end @testset "sparse gemv" begin - @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCOO, oneSparseMatrixCSR) + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCOO, oneSparseMatrixCSR, oneSparseMatrixCSC) @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] A = sprand(T, 20, 10, 0.5) x = transa == 'N' ? rand(T, 10) : rand(T, 20) @@ -1119,119 +1130,129 @@ end end @testset "sparse gemm" begin - @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] - (transb == 'N') || continue - A = sprand(T, 10, 10, 0.5) - B = transb == 'N' ? rand(T, 10, 2) : rand(T, 2, 10) - C = rand(T, 10, 2) + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC) + @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] + @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)] + (transb == 'N') || continue + A = sprand(T, 10, 10, 0.5) + B = transb == 'N' ? rand(T, 10, 2) : rand(T, 2, 10) + C = rand(T, 10, 2) - dA = oneSparseMatrixCSR(A) - dB = oneMatrix{T}(B) - dC = oneMatrix{T}(C) + dA = SparseMatrix(A) + dB = oneMatrix{T}(B) + dC = oneMatrix{T}(C) - alpha = rand(T) - beta = rand(T) - oneMKL.sparse_optimize_gemm!(transa, dA) - oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC) - @test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC) + alpha = rand(T) + beta = rand(T) + oneMKL.sparse_optimize_gemm!(transa, dA) + oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC) + @test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC) + end end end end @testset "sparse symv" begin - @testset "uplo = $uplo" for uplo in ('L', 'U') - A = sprand(T, 10, 10, 0.5) - A = A + A' - x = rand(T, 10) - y = rand(T, 10) - - dA = uplo == 'L' ? oneSparseMatrixCSR(A |> tril) : oneSparseMatrixCSR(A |> triu) - dx = oneVector{T}(x) - dy = oneVector{T}(y) - - alpha = rand(T) - beta = rand(T) - oneMKL.sparse_symv!(uplo, alpha, dA, dx, beta, dy) - # @test alpha * A * x + beta * y ≈ collect(dy) - end - end - - @testset "sparse trmv" begin - @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular), - ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] - (transa == 'N') || continue + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC) + @testset "uplo = $uplo" for uplo in ('L', 'U') A = sprand(T, 10, 10, 0.5) + A = A + A' x = rand(T, 10) y = rand(T, 10) - B = uplo == 'L' ? tril(A) : triu(A) - B = diag == 'U' ? B - Diagonal(B) + I : B - dA = oneSparseMatrixCSR(B) + dA = uplo == 'L' ? SparseMatrix(A |> tril) : SparseMatrix(A |> triu) dx = oneVector{T}(x) dy = oneVector{T}(y) alpha = rand(T) beta = rand(T) - - oneMKL.sparse_optimize_trmv!(uplo, transa, diag, dA) - oneMKL.sparse_trmv!(uplo, transa, diag, alpha, dA, dx, beta, dy) - @test alpha * wrapper(opa(A)) * x + beta * y ≈ collect(dy) + oneMKL.sparse_symv!(uplo, alpha, dA, dx, beta, dy) + # @test alpha * A * x + beta * y ≈ collect(dy) end end end - @testset "sparse trsv" begin - @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular), - ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] - (transa == 'N') || continue - alpha = rand(T) - A = rand(T, 10, 10) + I - A = sparse(A) - x = rand(T, 10) - y = rand(T, 10) + @testset "sparse trmv" begin + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC) + @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] + for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular), + ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] + (transa == 'N') || continue + A = sprand(T, 10, 10, 0.5) + x = rand(T, 10) + y = rand(T, 10) - B = uplo == 'L' ? tril(A) : triu(A) - B = diag == 'U' ? B - Diagonal(B) + I : B - dA = oneSparseMatrixCSR(B) - dx = oneVector{T}(x) - dy = oneVector{T}(y) + B = uplo == 'L' ? tril(A) : triu(A) + B = diag == 'U' ? B - Diagonal(B) + I : B + dA = SparseMatrix(B) + dx = oneVector{T}(x) + dy = oneVector{T}(y) + + alpha = rand(T) + beta = rand(T) - oneMKL.sparse_optimize_trsv!(uplo, transa, diag, dA) - oneMKL.sparse_trsv!(uplo, transa, diag, alpha, dA, dx, dy) - y = wrapper(opa(A)) \ (alpha * x) - @test y ≈ collect(dy) + oneMKL.sparse_optimize_trmv!(uplo, transa, diag, dA) + oneMKL.sparse_trmv!(uplo, transa, diag, alpha, dA, dx, beta, dy) + @test alpha * wrapper(opa(A)) * x + beta * y ≈ collect(dy) + end end end end - @testset "sparse trsm" begin - @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] - @testset "transx = $transx" for (transx, opx) in [('N', identity), ('T', transpose), ('C', adjoint)] - (transx != 'N') && continue + @testset "sparse trsv" begin + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC) + @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular), - ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] + ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] (transa == 'N') || continue alpha = rand(T) A = rand(T, 10, 10) + I A = sparse(A) - X = transx == 'N' ? rand(T, 10, 4) : rand(T, 4, 10) - Y = rand(T, 10, 4) + x = rand(T, 10) + y = rand(T, 10) B = uplo == 'L' ? tril(A) : triu(A) B = diag == 'U' ? B - Diagonal(B) + I : B - dA = oneSparseMatrixCSR(B) - dX = oneMatrix{T}(X) - dY = oneMatrix{T}(Y) - - oneMKL.sparse_optimize_trsm!(uplo, transa, diag, dA) - oneMKL.sparse_trsm!(uplo, transa, transx, diag, alpha, dA, dX, dY) - Y = wrapper(opa(A)) \ (alpha * opx(X)) - @test Y ≈ collect(dY) + dA = SparseMatrix(B) + dx = oneVector{T}(x) + dy = oneVector{T}(y) + + oneMKL.sparse_optimize_trsv!(uplo, transa, diag, dA) + oneMKL.sparse_trsv!(uplo, transa, diag, alpha, dA, dx, dy) + y = wrapper(opa(A)) \ (alpha * x) + @test y ≈ collect(dy) + end + end + end + end - oneMKL.sparse_optimize_trsm!(uplo, transa, diag, 4, dA) + @testset "sparse trsm" begin + @testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC) + @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)] + @testset "transx = $transx" for (transx, opx) in [('N', identity), ('T', transpose), ('C', adjoint)] + (transx != 'N') && continue + for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular), + ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)] + (transa == 'N') || continue + alpha = rand(T) + A = rand(T, 10, 10) + I + A = sparse(A) + X = transx == 'N' ? rand(T, 10, 4) : rand(T, 4, 10) + Y = rand(T, 10, 4) + + B = uplo == 'L' ? tril(A) : triu(A) + B = diag == 'U' ? B - Diagonal(B) + I : B + dA = SparseMatrix(B) + dX = oneMatrix{T}(X) + dY = oneMatrix{T}(Y) + + oneMKL.sparse_optimize_trsm!(uplo, transa, diag, dA) + oneMKL.sparse_trsm!(uplo, transa, transx, diag, alpha, dA, dX, dY) + Y = wrapper(opa(A)) \ (alpha * opx(X)) + @test Y ≈ collect(dY) + + oneMKL.sparse_optimize_trsm!(uplo, transa, diag, 4, dA) + end end end end