Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/sparse/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ ROCSparseMatrixCSR{T}(Mat::SparseMatrixCSC) where {T} = ROCSparseMatrixCSR(ROCSp
ROCSparseMatrixBSR{T}(Mat::SparseMatrixCSC, blockdim) where {T} = ROCSparseMatrixBSR(ROCSparseMatrixCSR{T}(Mat), blockdim)
ROCSparseMatrixCOO{T}(Mat::SparseMatrixCSC) where {T} = ROCSparseMatrixCOO(ROCSparseMatrixCSR{T}(Mat))

# CPU sparse transpose/adjoint → ROC (Transpose/Adjoint of a CPU CSC)
ROCSparseMatrixCOO{T}(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {T, Tv} = ROCSparseMatrixCOO{T}(ROCSparseMatrixCSR{T}(Mat))
ROCSparseMatrixCOO{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} = ROCSparseMatrixCOO{T}(ROCSparseMatrixCSR{T}(Mat))

# untyped variants
ROCSparseVector(x::AbstractSparseArray{T}) where {T} = ROCSparseVector{T}(x)
ROCSparseMatrixCSC(x::AbstractSparseArray{T}) where {T} = ROCSparseMatrixCSC{T}(x)
Expand All @@ -420,8 +424,23 @@ ROCSparseMatrixCSR(x::Transpose{T}) where {T} = ROCSparseMatrixCSR{T}(x)
ROCSparseMatrixCSR(x::Adjoint{T}) where {T} = ROCSparseMatrixCSR{T}(x)
ROCSparseMatrixCSC(x::Transpose{T}) where {T} = ROCSparseMatrixCSC{T}(x)
ROCSparseMatrixCSC(x::Adjoint{T}) where {T} = ROCSparseMatrixCSC{T}(x)

# TODO adjoint / transpose: GPUArrays._sptranspose
ROCSparseMatrixCOO(x::Transpose{T}) where {T} = ROCSparseMatrixCOO{T}(x)
ROCSparseMatrixCOO(x::Adjoint{T}) where {T} = ROCSparseMatrixCOO{T}(x)

# GPU-to-GPU transpose/adjoint constructors:
# materialize the transposed/conjugate-transposed matrix using GPUArrays._sptranspose / _spadjoint (implemented in interfaces.jl).
ROCSparseMatrixCSR(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCSR(GPUArrays._sptranspose(parent(x)))
ROCSparseMatrixCSC(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCSC(GPUArrays._sptranspose(parent(x)))
ROCSparseMatrixCOO(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCOO(GPUArrays._sptranspose(parent(x)))
ROCSparseMatrixCSR(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCSR(GPUArrays._spadjoint(parent(x)))
ROCSparseMatrixCSC(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCSC(GPUArrays._spadjoint(parent(x)))
ROCSparseMatrixCOO(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} =
ROCSparseMatrixCOO(GPUArrays._spadjoint(parent(x)))

# gpu to cpu
SparseVector(x::ROCSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x)))
Expand Down
11 changes: 11 additions & 0 deletions src/sparse/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ for SparseMatrixType in [:ROCSparseMatrixCSC, :ROCSparseMatrixCSR]
end
end

ROCSparseMatrixCOO(S::Diagonal) = ROCSparseMatrixCOO(roc(S))
ROCSparseMatrixCOO(S::Diagonal{T, <:ROCArray}) where {T} = ROCSparseMatrixCOO{T}(S)
ROCSparseMatrixCOO{Tv}(S::Diagonal{T, <:ROCArray}) where {Tv, T} = ROCSparseMatrixCOO{Tv, Cint}(S)
function ROCSparseMatrixCOO{Tv, Ti}(S::Diagonal{T, <:ROCArray}) where {Tv, Ti, T}
m = size(S, 1)
return ROCSparseMatrixCOO{Tv, Ti}(ROCVector(1:m), ROCVector(1:m), Tv.(S.diag), (m, m), m)
end

# by flipping rows and columns, we can use that to get CSC to CSR too
for (elty, fname) in ((:Float32, :rocsparse_scsr2csc), (:Float64, :rocsparse_dcsr2csc),
(:ComplexF32, :rocsparse_ccsr2csc), (:ComplexF64, :rocsparse_zcsr2csc))
Expand Down Expand Up @@ -323,6 +331,8 @@ function ROCSparseMatrixCSR(coo::ROCSparseMatrixCOO{Tv}, ind::SparseChar='O') wh
rocsparse_coo2csr(handle(), coo.rowInd, nnz(coo), m, csrRowPtr, ind)
ROCSparseMatrixCSR{Tv}(csrRowPtr, coo.colInd, nonzeros(coo), size(coo))
end
# Typed forwarding constructors: allow ROCSparseMatrixCSR{Tv,Ti}(coo) as called by GPUArrays generics
ROCSparseMatrixCSR{Tv,Ti}(coo::ROCSparseMatrixCOO{Tv,Ti}) where {Tv,Ti} = ROCSparseMatrixCSR(coo)

function ROCSparseMatrixCOO(csr::ROCSparseMatrixCSR{Tv}, ind::SparseChar='O') where Tv
m,n = size(csr)
Expand All @@ -334,6 +344,7 @@ end
### CSC/BSR to COO and viceversa

ROCSparseMatrixCSC(coo::ROCSparseMatrixCOO) = ROCSparseMatrixCSC(ROCSparseMatrixCSR(coo)) # no direct conversion
ROCSparseMatrixCSC{Tv,Ti}(coo::ROCSparseMatrixCOO{Tv,Ti}) where {Tv,Ti} = ROCSparseMatrixCSC(coo)
ROCSparseMatrixCOO(csc::ROCSparseMatrixCSC) = ROCSparseMatrixCOO(ROCSparseMatrixCSR(csc)) # no direct conversion
ROCSparseMatrixBSR(coo::ROCSparseMatrixCOO, blockdim) = ROCSparseMatrixBSR(ROCSparseMatrixCSR(coo), blockdim) # no direct conversion
ROCSparseMatrixCOO(bsr::ROCSparseMatrixBSR) = ROCSparseMatrixCOO(ROCSparseMatrixCSR(bsr)) # no direct conversion
Expand Down
243 changes: 163 additions & 80 deletions src/sparse/interfaces.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,35 @@
# interfacing with other packages

# Materialize transpose/adjoint of a sparse ROC matrix.
function GPUArrays._sptranspose(A::ROCSparseMatrixCSR)
ROCSparseMatrixCSR(ROCSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A))))
end
function GPUArrays._spadjoint(A::ROCSparseMatrixCSR)
ROCSparseMatrixCSR(ROCSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A))))
end

function GPUArrays._sptranspose(A::ROCSparseMatrixCSC)
ROCSparseMatrixCSC(ROCSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A))))
end
function GPUArrays._spadjoint(A::ROCSparseMatrixCSC)
ROCSparseMatrixCSC(ROCSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A))))
end

function GPUArrays._sptranspose(A::ROCSparseMatrixCOO)
# swap row/col indices and re-sort so the result is sorted by row
sparse(A.colInd, A.rowInd, A.nzVal, reverse(size(A))..., fmt=:coo)
end
function GPUArrays._spadjoint(A::ROCSparseMatrixCOO)
sparse(A.colInd, A.rowInd, conj(A.nzVal), reverse(size(A))..., fmt=:coo)
end

# (type_wrapper, value_unwrapper) pairs for generating methods over plain, transposed, and adjoint arguments.
const adjtrans_wrappers = (
(identity, identity),
(M -> :(Transpose{T, <:$M}), M -> :(_sptranspose(parent($M)))),
(M -> :(Adjoint{T, <:$M}), M -> :(_spadjoint(parent($M)))),
)

function mv_wrapper(
transa::SparseChar, alpha::Number, A::ROCSparseMatrix, X::DenseROCVector{T},
beta::Number, Y::ROCVector{T},
Expand Down Expand Up @@ -76,73 +106,41 @@ function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMa
mm!(tA, tB, alpha, A, B, beta, C, 'O')
end

Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
Base.:(-)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, -one(eltype(A)), B, 'O')

Base.:(+)(A::ROCSparseMatrixCSR, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T = A + Transpose(conj(B.parent))
Base.:(-)(A::ROCSparseMatrixCSR, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T = A - Transpose(conj(B.parent))
Base.:(+)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T = Transpose(conj(A.parent)) + B
Base.:(-)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T = Transpose(conj(A.parent)) - B
Base.:(+)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T =
Transpose(conj(A.parent)) + B
Base.:(-)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T =
Transpose(conj(A.parent)) - B

function Base.:(+)(A::ROCSparseMatrixCSR, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
cscB = ROCSparseMatrixCSC(B.parent)
transB = ROCSparseMatrixCSR(cscB.colPtr, cscB.rowVal, cscB.nzVal, size(cscB))
return geam(one(T), A, one(T), transB, 'O')
end

function Base.:(-)(A::ROCSparseMatrixCSR, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
cscB = ROCSparseMatrixCSC(B.parent)
transB = ROCSparseMatrixCSR(cscB.colPtr, cscB.rowVal, cscB.nzVal, size(cscB))
return geam(one(T), A, -one(T), transB, 'O')
end

function Base.:(+)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T
cscA = ROCSparseMatrixCSC(A.parent)
transA = ROCSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, size(cscA))
geam(one(T), transA, one(T), B, 'O')
end

function Base.:(-)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T
cscA = ROCSparseMatrixCSC(A.parent)
transA = ROCSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, size(cscA))
geam(one(T), transA, -one(T), B, 'O')
end

function Base.:(+)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
C = geam(one(T), A.parent, one(T), B.parent, 'O')
cscC = ROCSparseMatrixCSC(C)
return ROCSparseMatrixCSR(cscC.colPtr, cscC.rowVal, cscC.nzVal, size(cscC))
end

function Base.:(-)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
C = geam(one(T), A.parent, -one(T), B.parent, 'O')
cscC = ROCSparseMatrixCSC(C)
return ROCSparseMatrixCSR(cscC.colPtr, cscC.rowVal, cscC.nzVal, size(cscC))
end

function Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrix)
csrB = ROCSparseMatrixCSR(B)
return geam(one(eltype(A)), A, one(eltype(A)), csrB, 'O')
end

function Base.:(-)(A::ROCSparseMatrixCSR, B::ROCSparseMatrix)
csrB = ROCSparseMatrixCSR(B)
return geam(one(eltype(A)), A, -one(eltype(A)), csrB, 'O')
# +/- for all combinations of plain/transposed/adjoint CSR and CSC, generated via adjtrans_wrappers.
for op in (:(+), :(-))
for (wrapa, unwrapa) in adjtrans_wrappers, (wrapb, unwrapb) in adjtrans_wrappers
for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}))
TypeA = wrapa(SparseMatrixType)
TypeB = wrapb(SparseMatrixType)
@eval Base.$op(A::$TypeA, B::$TypeB) where {T <: BlasFloat} =
geam(one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)), 'O')
end
end
# COO: materialise both sides as CSR, run geam, convert back
@eval Base.$op(
A::Union{ROCSparseMatrixCOO{T}, Transpose{T,<:ROCSparseMatrixCOO}, Adjoint{T,<:ROCSparseMatrixCOO}},
B::Union{ROCSparseMatrixCOO{T}, Transpose{T,<:ROCSparseMatrixCOO}, Adjoint{T,<:ROCSparseMatrixCOO}},
) where {T <: BlasFloat} =
ROCSparseMatrixCOO(Base.$op(ROCSparseMatrixCSR(A), ROCSparseMatrixCSR(B)))
end

function Base.:(+)(A::ROCSparseMatrix, B::ROCSparseMatrixCSR)
csrA = ROCSparseMatrixCSR(A)
return geam(one(eltype(A)), csrA, one(eltype(A)), B, 'O')
# Cross-format +/- for CSR/CSC/BSR mixtures: normalise both operands to CSR, then geam
for op in (:(+), :(-))
@eval begin
Base.$op(A::ROCSparseMatrixCSR{T}, B::ROCSparseMatrixCSC{T}) where {T <: BlasFloat} =
geam(one(T), A, $(op)(one(T)), ROCSparseMatrixCSR(B), 'O')
Base.$op(A::ROCSparseMatrixCSC{T}, B::ROCSparseMatrixCSR{T}) where {T <: BlasFloat} =
geam(one(T), ROCSparseMatrixCSR(A), $(op)(one(T)), B, 'O')
Base.$op(A::ROCSparseMatrixCSR{T}, B::ROCSparseMatrixBSR{T}) where {T <: BlasFloat} =
geam(one(T), A, $(op)(one(T)), ROCSparseMatrixCSR(B), 'O')
Base.$op(A::ROCSparseMatrixBSR{T}, B::ROCSparseMatrixCSR{T}) where {T <: BlasFloat} =
geam(one(T), ROCSparseMatrixCSR(A), $(op)(one(T)), B, 'O')
end
end

function Base.:(-)(A::ROCSparseMatrix, B::ROCSparseMatrixCSR)
csrA = ROCSparseMatrixCSR(A)
return geam(one(eltype(A)), csrA, -one(eltype(A)), B, 'O')
end
# vector +/-
Base.:(+)(A::ROCSparseVector{T}, B::ROCSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, one(T), B, 'O')
Base.:(-)(A::ROCSparseVector{T}, B::ROCSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, -one(T), B, 'O')

# triangular

Expand Down Expand Up @@ -199,10 +197,7 @@ end

## uniform scaling

# these operations materialize the identity matrix and re-use broadcast
# TODO: can we do without this, and just use the broadcast implementation
# with a singleton argument it knows how to index?

# TODO: use a broadcast singleton for I instead of materialising the full sparse identity.
function _sparse_identity(
::Type{<:ROCSparseMatrixCSR{<:Any,Ti}}, I::UniformScaling{Tv}, dims::Dims,
) where {Tv, Ti}
Expand All @@ -223,24 +218,112 @@ function _sparse_identity(
ROCSparseMatrixCSC{Tv,Ti}(colPtr, rowVal, nzVal, dims)
end

# TODO COO

Base.:(+)(A::Union{ROCSparseMatrixCSR,ROCSparseMatrixCSC}, J::UniformScaling) =
A .+ _sparse_identity(typeof(A), J, size(A))
function _sparse_identity(
::Type{<:ROCSparseMatrixCOO{<:Any,Ti}}, I::UniformScaling{Tv}, dims::Dims,
) where {Tv, Ti}
len = min(dims[1], dims[2])
rowInd = ROCVector{Ti}(1:len)
colInd = ROCVector{Ti}(1:len)
nzVal = AMDGPU.fill(I.λ, len)
ROCSparseMatrixCOO{Tv,Ti}(rowInd, colInd, nzVal, dims)
end

Base.:(-)(J::UniformScaling, A::Union{ROCSparseMatrixCSR,ROCSparseMatrixCSC}) =
_sparse_identity(typeof(A), J, size(A)) .- A
# Scale all nzVals of a COO matrix by scalar λ.
_coo_scale(A::ROCSparseMatrixCOO{T}, λ) where {T} =
ROCSparseMatrixCOO(A.rowInd, A.colInd, A.nzVal .* λ, size(A), nnz(A))

# UniformScaling +/-/* for all formats/wrappers; typeof(A′) passes a concrete type to _sparse_identity
# (SparseMatrixType is a UnionAll with Ti unbound at runtime and won't match its signatures).
for (wrapa, unwrapa) in adjtrans_wrappers
for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
TypeA = wrapa(SparseMatrixType)
if SparseMatrixType != :(ROCSparseMatrixCOO{T})
# CSR/CSC: identity is the same format; broadcasting works for .*
@eval begin
Base.:(+)(A::$TypeA, J::UniformScaling) where {T} =
let A′ = $(unwrapa(:A)); A′ + _sparse_identity(typeof(A′), J, size(A)) end
Base.:(+)(J::UniformScaling, A::$TypeA) where {T} =
let A′ = $(unwrapa(:A)); _sparse_identity(typeof(A′), J, size(A)) + A′ end
Base.:(-)(A::$TypeA, J::UniformScaling) where {T} =
let A′ = $(unwrapa(:A)); A′ - _sparse_identity(typeof(A′), J, size(A)) end
Base.:(-)(J::UniformScaling, A::$TypeA) where {T} =
let A′ = $(unwrapa(:A)); _sparse_identity(typeof(A′), J, size(A)) - A′ end
Base.:(*)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) .* J.λ
Base.:(*)(J::UniformScaling, A::$TypeA) where {T} = J.λ .* $(unwrapa(:A))
end
else
# COO: broadcast not supported → route +/- through CSR, scale nzVal for *
@eval begin
Base.:(+)(A::$TypeA, J::UniformScaling) where {T} =
let A′ = $(unwrapa(:A)); csr = ROCSparseMatrixCSR(A′)
ROCSparseMatrixCOO(csr + _sparse_identity(typeof(csr), J, size(A))) end
Base.:(+)(J::UniformScaling, A::$TypeA) where {T} =
let A′ = $(unwrapa(:A)); csr = ROCSparseMatrixCSR(A′)
ROCSparseMatrixCOO(_sparse_identity(typeof(csr), J, size(A)) + csr) end
Base.:(-)(A::$TypeA, J::UniformScaling) where {T} =
let A′ = $(unwrapa(:A)); csr = ROCSparseMatrixCSR(A′)
ROCSparseMatrixCOO(csr - _sparse_identity(typeof(csr), J, size(A))) end
Base.:(-)(J::UniformScaling, A::$TypeA) where {T} =
let A′ = $(unwrapa(:A)); csr = ROCSparseMatrixCSR(A′)
ROCSparseMatrixCOO(_sparse_identity(typeof(csr), J, size(A)) - csr) end
Base.:(*)(A::$TypeA, J::UniformScaling) where {T} = _coo_scale($(unwrapa(:A)), J.λ)
Base.:(*)(J::UniformScaling, A::$TypeA) where {T} = _coo_scale($(unwrapa(:A)), J.λ)
end
end
end
end

# TODO: let Broadcast handle this automatically (a la SparseArrays.PromoteToSparse)
for SparseMatrixType in [:ROCSparseMatrixCSC, :ROCSparseMatrixCSR], op in [:(+), :(-)]
@eval begin
function Base.$op(lhs::Diagonal{T,<:ROCArray}, rhs::$SparseMatrixType{T}) where T
return $op($SparseMatrixType(lhs), rhs)
# +/- with Diagonal: convert it to the same sparse format, then geam.
for (wrapa, unwrapa) in adjtrans_wrappers, op in (:(+), :(-))
for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
TypeA = wrapa(SparseMatrixType)
@eval begin
function Base.$op(lhs::Diagonal, rhs::$TypeA) where {T}
return $op($SparseMatrixType(lhs), $(unwrapa(:rhs)))
end
function Base.$op(lhs::$TypeA, rhs::Diagonal) where {T}
return $op($(unwrapa(:lhs)), $SparseMatrixType(rhs))
end
end
function Base.$op(lhs::$SparseMatrixType{T}, rhs::Diagonal{T,<:ROCArray}) where T
return $op(lhs, $SparseMatrixType(rhs))
end
end

# * with Diagonal for CSR/CSC: convert to COO, scale nzVal by d[colInd] or d[rowInd], convert back.
for (wrapa, unwrapa) in adjtrans_wrappers
for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}))
FmtCtor = SparseMatrixType == :(ROCSparseMatrixCSR{T}) ? :ROCSparseMatrixCSR : :ROCSparseMatrixCSC
TypeA = wrapa(SparseMatrixType)
@eval begin
function Base.:(*)(lhs::$TypeA, rhs::Diagonal) where {T}
A = $(unwrapa(:lhs))
d = rhs isa Diagonal{<:Any, <:ROCArray} ? T.(rhs.diag) : ROCArray(T.(rhs.diag))
coo = ROCSparseMatrixCOO(A)
$FmtCtor(ROCSparseMatrixCOO(coo.rowInd, coo.colInd, coo.nzVal .* d[coo.colInd], size(coo), nnz(coo)))
end
function Base.:(*)(lhs::Diagonal, rhs::$TypeA) where {T}
A = $(unwrapa(:rhs))
d = lhs isa Diagonal{<:Any, <:ROCArray} ? T.(lhs.diag) : ROCArray(T.(lhs.diag))
coo = ROCSparseMatrixCOO(A)
$FmtCtor(ROCSparseMatrixCOO(coo.rowInd, coo.colInd, d[coo.rowInd] .* coo.nzVal, size(coo), nnz(coo)))
end
end
end
end

# TODO _sptranspose / _spadjoint
# * with Diagonal for COO: scale nzVal by d[colInd] or d[rowInd].
for (wrapa, unwrapa) in adjtrans_wrappers
TypeA = wrapa(:(ROCSparseMatrixCOO{T}))
@eval begin
function Base.:(*)(lhs::$TypeA, rhs::Diagonal) where {T}
A = $(unwrapa(:lhs))
d = rhs isa Diagonal{<:Any, <:ROCArray} ? T.(rhs.diag) : ROCArray(T.(rhs.diag))
ROCSparseMatrixCOO(A.rowInd, A.colInd, A.nzVal .* d[A.colInd], size(A), nnz(A))
end
function Base.:(*)(lhs::Diagonal, rhs::$TypeA) where {T}
A = $(unwrapa(:rhs))
d = lhs isa Diagonal{<:Any, <:ROCArray} ? T.(lhs.diag) : ROCArray(T.(lhs.diag))
ROCSparseMatrixCOO(A.rowInd, A.colInd, d[A.rowInd] .* A.nzVal, size(A), nnz(A))
end
end
end
Loading