Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/mkl/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
nnz::Ti
end

# Device-side sparse matrix whose buffers follow the compressed sparse column (CSC) layout.
# `Tv` is the element type of the stored values, `Ti` the integer type of the index buffers.
# NOTE(review): the host-to-device constructor registers these buffers with oneMKL as the
# CSR data of the transposed matrix, so `handle` presumably describes Aᵀ — confirm before
# passing it to routines that take an explicit transpose flag.
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
# oneMKL matrix handle associated with the three buffers below
handle::matrix_handle_t
# column pointer buffer
colPtr::oneVector{Ti}
# row index of each stored entry
rowVal::oneVector{Ti}
# value of each stored entry
nzVal::oneVector{Tv}
# logical size of the matrix as (rows, columns)
dims::NTuple{2,Int}
# number of stored entries, kept with the index type `Ti`
nnz::Ti
end

mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
rowInd::oneVector{Ti}
Expand All @@ -37,6 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
# Return the device buffer holding the stored values (`nzVal`) of a sparse matrix.
function SparseArrays.nonzeros(A::oneAbstractSparseMatrix)
    return A.nzVal
end

for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
:oneSparseMatrixCSC => :SparseMatrixCSC,
:oneSparseMatrixCOO => :SparseMatrixCSC]
@eval Base.show(io::IOContext, x::$gpu) =
show(io, $cpu(x))
Expand Down
11 changes: 11 additions & 0 deletions lib/mkl/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
end

# Compute `C = alpha * op(A) * B + beta * C` for a CSC sparse matrix `A` and a dense
# device vector `B`, with `alpha`/`beta` taken from `_add`. Returns what `sparse_gemv!`
# returns.
#
# `tA` is the operation flag handed over by LinearAlgebra. It is forwarded as the *logical*
# operation on `A`, without flipping: `sparse_gemv!` for `oneSparseMatrixCSC` already
# translates the logical operation into the one applied to the stored CSR(Aᵀ) handle.
# Flipping 'N' <-> 'T' here as well would transpose twice and compute Aᵀ*B for `tA == 'N'`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
    # Symmetric/Hermitian wrapper flags: for real element types op(A) == A, i.e. 'N'.
    # `Char(tA)` keeps the flag concretely typed for the `trans::Char` wrapper signature.
    opA = tA in ('S', 's', 'H', 'h') ? 'N' : Char(tA)
    return sparse_gemv!(opA, _add.alpha, A, B, _add.beta, C)
end

# Compute `C = alpha * op(A) * op(B) + beta * C` for a CSR sparse matrix `A` and a dense
# device matrix `B`, with `alpha`/`beta` taken from `_add`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
    # Flags produced by symmetric/Hermitian wrappers are mapped to a plain product.
    wrapper_flags = ('S', 's', 'H', 'h')
    opA = tA in wrapper_flags ? 'N' : tA
    opB = tB in wrapper_flags ? 'N' : tB
    return sparse_gemm!(opA, opB, _add.alpha, A, B, _add.beta, C)
end

# Compute `C = alpha * op(A) * op(B) + beta * C` for a CSC sparse matrix `A` and a dense
# device matrix `B`, with `alpha`/`beta` taken from `_add`. Returns what `sparse_gemm!`
# returns.
#
# `tA` is forwarded as the *logical* operation on `A`, without flipping: `sparse_gemm!` for
# `oneSparseMatrixCSC` already translates it into the operation applied to the stored
# CSR(Aᵀ) handle. Flipping 'N' <-> 'T' here as well would transpose twice and compute
# Aᵀ*op(B) for `tA == 'N'`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
    # Symmetric/Hermitian wrapper flags: for real element types op(X) == X, i.e. 'N'.
    wrapper_flags = ('S', 's', 'H', 'h')
    opA = tA in wrapper_flags ? 'N' : tA
    opB = tB in wrapper_flags ? 'N' : tB
    return sparse_gemm!(opA, opB, _add.alpha, A, B, _add.beta, C)
end

for SparseMatrixType in (:oneSparseMatrixCSR,)
@eval begin
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
Expand Down
202 changes: 202 additions & 0 deletions lib/mkl/wrappers_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
A_csc = SparseMatrixCSC(At |> transpose)
return A_csc
end

# Upload a host `SparseMatrixCSC` to the device and wrap it in a `oneSparseMatrixCSC`.
#
# The three CSC buffers are copied to device vectors and registered with a freshly
# initialised oneMKL matrix handle. The returned object keeps the logical size `(m, n)`
# of `A`, and a finalizer releases the handle when the object is collected.
function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
    handle_ptr = Ref{matrix_handle_t}()
    onemklXsparse_init_matrix_handle(handle_ptr)
    m, n = size(A)
    colPtr = oneVector{$intty}(A.colptr)
    rowVal = oneVector{$intty}(A.rowval)
    nzVal = oneVector{$elty}(A.nzval)
    # Count stored entries via the column pointers: `nzval` may be longer than the number
    # of stored entries, so its length is not a reliable `nnz`.
    nnzA = SparseArrays.nnz(A)
    queue = global_queue(context(nzVal), device())
    # CSC of A is CSR of Aᵀ: the buffers are registered as CSR data with the two
    # dimensions swapped (n, m) and index base 'O'.
    $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal)
    dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
    finalizer(sparse_release_matrix_handle, dA)
    return dA
end

# Copy a device CSC matrix back to the host as a `SparseMatrixCSC`.
#
# Only the three device buffers and the stored dimensions are read; the oneMKL handle is
# not involved, so no handle is created here.
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
    return SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
end
end
end

Expand Down Expand Up @@ -100,6 +121,78 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
end
end

# Special handling for CSC matrices: the oneMKL handle holds S = Aᵀ in CSR form, so every
# requested operation on A is translated into the equivalent operation on S.
for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
                      (:onemklDsparse_gemv, :Float64),
                      (:onemklCsparse_gemv, :ComplexF32),
                      (:onemklZsparse_gemv, :ComplexF64))
    @eval begin
        # Compute `y = alpha * op(A) * x + beta * y` in place and return `y`.
        #
        # `trans` selects op(A) for the logical matrix A ('N': A, 'T': Aᵀ, 'C': Aᴴ); any
        # other character raises an `ArgumentError`. Mapping onto the stored S = Aᵀ:
        #   - 'N': A*x  == Sᵀ*x                 -> op(S) = 'T'
        #   - 'T': Aᵀ*x == S*x                  -> op(S) = 'N'
        #   - 'C', real eltype: Aᴴ == Aᵀ        -> op(S) = 'N'
        #   - 'C', complex eltype: Aᴴ == conj(S) is not expressible as a single op(S).
        #     Conjugating the whole update gives
        #         conj(y_new) = conj(alpha) * S * conj(x) + conj(beta) * conj(y)
        #     so the call uses op(S) = 'N' on conjugated operands and `y` is conjugated
        #     back afterwards. This path allocates one temporary the size of `x`.
        function sparse_gemv!(trans::Char,
                              alpha::Number,
                              A::oneSparseMatrixCSC{$elty},
                              x::oneStridedVector{$elty},
                              beta::Number,
                              y::oneStridedVector{$elty})

            # Reject unknown flags instead of silently treating them as 'C'.
            trans in ('N', 'T', 'C') ||
                throw(ArgumentError(string("trans must be 'N', 'T' or 'C', got ", repr(trans))))

            queue = global_queue(context(x), device())

            if trans == 'N'
                $fname(sycl_queue(queue), 'T', alpha, A.handle, x, beta, y)
            elseif trans == 'T' || !($elty <: Complex)
                # 'T', or 'C' on a real eltype where Aᴴ == Aᵀ
                $fname(sycl_queue(queue), 'N', alpha, A.handle, x, beta, y)
            else
                # 'C' on a complex eltype: conjugate, multiply by S, conjugate back
                y .= conj.(y)
                x_conj = similar(x)
                x_conj .= conj.(x)

                $fname(sycl_queue(queue), 'N', conj(alpha), A.handle, x_conj, conj(beta), y)

                y .= conj.(y)
            end
            return y
        end
    end
end

# Run the oneMKL gemv optimisation pass for a CSC matrix and return `A`.
#
# The handle stores the transposed matrix in CSR form, so the hint is given for the
# operation on the stored data: a requested 'N' becomes 'T', while 'T' and every other
# flag (the 'C' case, which is executed with op = 'N' plus conjugation) become 'N'.
function sparse_optimize_gemv!(trans::Char, A::oneSparseMatrixCSC)
    stored_op = trans == 'N' ? 'T' : 'N'
    q = global_queue(context(A.nzVal), device(A.nzVal))
    onemklXsparse_optimize_gemv(sycl_queue(q), stored_op, A.handle)
    return A
end

for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
(:onemklDsparse_gemm, :Float64),
(:onemklCsparse_gemm, :ComplexF32),
Expand Down Expand Up @@ -139,6 +232,115 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars
return A
end

# Special handling for CSC matrices: the oneMKL handle holds S = Aᵀ in CSR form, so every
# requested operation on A is translated into the equivalent operation on S.
for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
                      (:onemklDsparse_gemm, :Float64),
                      (:onemklCsparse_gemm, :ComplexF32),
                      (:onemklZsparse_gemm, :ComplexF64))
    @eval begin
        # Compute `C = alpha * op(A) * op(B) + beta * C` in place and return `C`.
        #
        # `transa` selects op(A) for the logical matrix A ('N', 'T' or 'C'; anything else
        # raises an `ArgumentError`), `transb` selects op(B). An `ArgumentError` is also
        # raised when op(B) and C disagree on the number of columns.
        # Mapping of `transa` onto the stored S = Aᵀ:
        #   - 'N': A*op(B)  == Sᵀ*op(B)         -> op(S) = 'T'
        #   - 'T': Aᵀ*op(B) == S*op(B)          -> op(S) = 'N'
        #   - 'C', real eltype: Aᴴ == Aᵀ        -> op(S) = 'N'
        #   - 'C', complex eltype: conjugate the whole update,
        #         conj(C_new) = conj(alpha) * S * conj(op(B)) + conj(beta) * conj(C)
        #     run it with op(S) = 'N', then conjugate `C` back.
        function sparse_gemm!(transa::Char,
                              transb::Char,
                              alpha::Number,
                              A::oneSparseMatrixCSC{$elty},
                              B::oneStridedMatrix{$elty},
                              beta::Number,
                              C::oneStridedMatrix{$elty})

            # Reject unknown flags instead of silently treating them as 'C'.
            transa in ('N', 'T', 'C') ||
                throw(ArgumentError(string("transa must be 'N', 'T' or 'C', got ", repr(transa))))

            mB, nB = size(B)
            mC, nC = size(C)
            (nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
            (mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
            # Number of right-hand sides is the column count of C. For transb == 'N' this
            # equals size(B, 2) (checked above); for a transposed B, size(B, 2) would be
            # the inner dimension instead.
            nrhs = nC
            ldb = max(1, stride(B, 2))
            ldc = max(1, stride(C, 2))

            queue = global_queue(context(C), device())

            if transa == 'N'
                $fname(sycl_queue(queue), 'C', 'T', transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
            elseif transa == 'T' || !($elty <: Complex)
                # 'T', or 'C' on a real eltype where Aᴴ == Aᵀ
                $fname(sycl_queue(queue), 'C', 'N', transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
            else
                # 'C' on a complex eltype
                C .= conj.(C)

                # Supply conj(op(B)):
                #   - transb 'N' / 'T': same flag, applied to a conjugated copy of B
                #   - otherwise ('C'):  conj(Bᴴ) == Bᵀ, so use B itself with flag 'T'
                local transb_eff::Char
                local Beff
                if transb == 'N' || transb == 'T'
                    transb_eff = transb
                    Beff = similar(B)
                    Beff .= conj.(B)
                else
                    transb_eff = 'T'
                    Beff = B
                end
                # The copy need not share B's leading dimension (B may be a strided view),
                # so take it from the array that is actually passed.
                ldb_eff = max(1, stride(Beff, 2))

                $fname(sycl_queue(queue), 'C', 'N', transb_eff, conj(alpha), A.handle, Beff, nrhs, ldb_eff, conj(beta), C, ldc)

                # Undo the conjugation to obtain C_new
                C .= conj.(C)
            end
            return C
        end
    end
end

# Run the basic oneMKL gemm optimisation pass for a CSC matrix and return `A`.
#
# The handle stores S = Aᵀ, so the hint names the operation on S: a requested 'N' becomes
# 'T'; 'T' and every other flag (the 'C' case, executed with op(S) = 'N') become 'N'.
function sparse_optimize_gemm!(trans::Char, A::oneSparseMatrixCSC)
    stored_op = trans == 'N' ? 'T' : 'N'
    q = global_queue(context(A.nzVal), device(A.nzVal))
    onemklXsparse_optimize_gemm(sycl_queue(q), stored_op, A.handle)
    return A
end

# Run the advanced oneMKL gemm optimisation pass for a CSC matrix and return `A`.
#
# `trans`/`transB` are the logical operations on A and B, `nrhs` is forwarded unchanged.
# The hint must describe the call that `sparse_gemm!` issues on the stored S = Aᵀ:
#   - op(A): 'N' becomes 'T'; 'T' and 'C' become 'N'.
#   - op(B): unchanged, except for a complex eltype with trans == 'C' and transB == 'C',
#     where `sparse_gemm!` runs the conjugated update with B itself and flag 'T'.
# The element type is read from the type parameter `Tv`.
function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSparseMatrixCSC{Tv}) where {Tv}
    actual_trans = trans == 'N' ? 'T' : 'N'
    actual_transB = (Tv <: Complex && trans == 'C' && transB == 'C') ? 'T' : transB
    queue = global_queue(context(A.nzVal), device(A.nzVal))
    onemklXsparse_optimize_gemm_advanced(sycl_queue(queue), 'C', actual_trans, actual_transB, A.handle, nrhs)
    return A
end

for (fname, elty) in ((:onemklSsparse_symv, :Float32),
(:onemklDsparse_symv, :Float64),
(:onemklCsparse_symv, :ComplexF32),
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
Loading