Skip to content
219 changes: 154 additions & 65 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
using Random: rand!

_fix_size(M, nrow, ncol) = M

# An immutable fixed size wrapper for matrices to work around
# the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
# This is more-of-less a stripped down version of FixedSizeArrays
# which we can't easily use without pulling that into the standard library.
struct _FixedSizeMatrix{Trans,R}
ref::R
nrow::Int
ncol::Int
function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R}
new{Trans,R}(ref, nrow, ncol)
end
Comment thread
dkarrasch marked this conversation as resolved.
end
@inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) =
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[]
@inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) =
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] = v

@inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) =
@inbounds transpose(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
@inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) =
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v)

@inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) =
@inbounds adjoint(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
@inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) =
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v)

@inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol)
@inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) =
_FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol)
@inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) =
_FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol)

const tilebufsize = 10800 # Approximately 32k/3

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
Expand Down Expand Up @@ -74,47 +109,94 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
return C
end

# Slow non-inlined functions for throwing the error without messing up the caller
@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt)
if At == 'N'
Anames = "first", "second"
else
Anames = "second", "first"
end
if Bt == 'N'
Bnames = "first", "second"
else
Bnames = "second", "first"
end
nA == mB ||
throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB"))
mA == mC ||
throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC"))
nB == nC ||
throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC"))
# unreachable
throw(DimensionMismatch("Unknown dimension mismatch"))
end

@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt}
mC = size(C, 1)
nC = size(C, 2)
mA = size(A, 1)
nA = size(A, 2)
mB = size(B, 1)
nB = size(B, 2)

_mA, _nA = At == 'N' ? (mA, nA) : (nA, mA)
_mB, _nB = Bt == 'N' ? (mB, nB) : (nB, mB)

if (_nA != _mB) | (_mA != mC) | (_nB != nC)
_matmul_size_error(mC, nC, _mA, _nA, _mB, _nB, At, Bt)
end
return mC, nC, mA, nA, mB, nB
end

@inline _matmul_size_AB(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('N'))
@inline _matmul_size_AtB(C, A, B) = _matmul_size(C, A, B, Val('T'), Val('N'))
@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T'))

function _spmatmul!(C, A, B, α, β)
size(A, 2) == size(B, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
size(A, 1) == size(C, 1) ||
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))"))
size(B, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
Cax2 = axes(C, 2)
Aax2 = axes(A, 2)
mC, nC, mA, nA, mB, nB = _matmul_size_AB(C, A, B)
nzv = nonzeros(A)
rv = rowvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for k in axes(C, 2)
@inbounds for col in axes(A,2)
αxj = B[col,k] * α
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
if α isa Bool && !α
return
end
B = _fix_size(B, mB, nB)
C = _fix_size(C, mC, nC)
for k in Cax2
@inbounds for col in Aax2
αxj = α isa Bool ? B[col,k] : B[col,k] * α
for j in nzrange(A, col)
C[rv[j], k] += nzv[j]*αxj
rvj = rv[j]
C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k])
end
end
end
C
end

function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
size(A, 2) == size(C, 1) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
size(A, 1) == size(B, 1) ||
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))"))
size(B, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
Cax2 = axes(C, 2)
Aax2 = axes(A, 2)
mC, nC, mA, nA, mB, nB = _matmul_size_AtB(C, A, B)
nzv = nonzeros(A)
rv = rowvals(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for k in axes(C, 2)
@inbounds for col in axes(A,2)
tmp = zero(eltype(C))
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
if α isa Bool && !α
return
end
C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc
B = _fix_size(B, mB, nB)
C = _fix_size(C, mC, nC)
for k in Cax2
@inbounds for col in Aax2
tmp = C0
for j in nzrange(A, col)
tmp += tfun(nzv[j])*B[rv[j],k]
tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp)
end
C[col,k] += tmp * α
C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k])
end
end
C
end

Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number, ::LinearAlgebra.BlasFlag.SyrkHerkGemm)
Expand All @@ -132,63 +214,71 @@ Base.@constprop :aggressive generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB,
LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)

function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
Aax2 = axes(A, 2)
Xax1 = axes(X, 1)
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
rv = rowvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for col in axes(A,2), k in nzrange(A, col)
Aiα = nzv[k] * α
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
if α isa Bool && !α
return
end
C = _fix_size(C, mC, nC)
X = _fix_size(X, mX, nX)
@inbounds for col in Aax2, k in nzrange(A, col)
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
rvk = rv[k]
@simd for multivec_row in axes(X,1)
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
@simd for multivec_row in Xax1
C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα,
C[multivec_row, col])
end
end
C
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
mX == size(C, 1) ||
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
size(A, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
Xax1 = axes(X, 1)
Cax2 = axes(C, 2)
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
rv = rowvals(A)
nzv = nonzeros(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
for multivec_row in axes(X,1), col in axes(C, 2)
@inbounds for k in nzrange(A, col)
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
if α isa Bool && !α
return
end
C = _fix_size(C, mC, nC)
X = _fix_size(X, mX, nX)
@inbounds for multivec_row in Xax1, col in Cax2
nzrng = nzrange(A, col)
if isempty(nzrng)
continue
end
tmp = C[multivec_row, col]
for k in nzrng
tmp = muladd(X[multivec_row, rv[k]],
(α isa Bool ? nzv[k] : nzv[k] * α), tmp)
end
C[multivec_row, col] = tmp
end
C
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
mA == size(C, 1) ||
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))"))
size(B, 1) == size(C, 2) ||
throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
Bax2 = axes(B, 2)
Aax1 = axes(A, 1)
mC, nC, mA, nA, mB, nB = _matmul_size_ABt(C, A, B)
rv = rowvals(B)
nzv = nonzeros(B)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for col in axes(B, 2), k in nzrange(B, col)
Biα = tfun(nzv[k]) * α
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
if α isa Bool && !α
return
end
C = _fix_size(C, mC, nC)
A = _fix_size(A, mA, nA)
@inbounds for col in Bax2, k in nzrange(B, col)
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
rvk = rv[k]
@simd for multivec_col in axes(A,1)
C[multivec_col, rvk] += A[multivec_col, col] * Biα
@simd for multivec_col in Aax1
C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk])
end
end
C
end

function *(A::Diagonal, b::AbstractSparseVector)
Expand Down Expand Up @@ -1243,7 +1333,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
rv = rowvals(A)
nzv = nonzeros(A)
let z = T(0), sumcol=z, αxj=z, aarc=z, α = α
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for k in axes(B,2)
for col in axes(B,1)
αxj = B[col,k] * α
Expand All @@ -1262,7 +1352,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
end
end
end
C
end

# row range up to (and including if excl=false) diagonal
Expand Down
23 changes: 10 additions & 13 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1930,9 +1930,9 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector,
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -1946,7 +1946,6 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector,
end
end
end
return y
end

function _At_or_Ac_mul_B!(tfun::Function,
Expand All @@ -1958,14 +1957,14 @@ function _At_or_Ac_mul_B!(tfun::Function,
"Matrix A has $n rows, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m columns, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
_nnz = length(xnzind)
_nnz == 0 && return y
_nnz == 0 && return

Ty = promote_op(matprod, eltype(A), eltype(x))
@inbounds for j = 1:m
Expand All @@ -1975,7 +1974,7 @@ function _At_or_Ac_mul_B!(tfun::Function,
end
y[j] += s * α
end
return y
return
end

function *(A::AdjOrTrans{<:Any,<:StridedMatrix}, x::AbstractSparseVector)
Expand Down Expand Up @@ -2053,9 +2052,9 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -2073,7 +2072,6 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars
end
end
end
return y
end

function _At_or_Ac_mul_B!(tfun::Function,
Expand All @@ -2085,9 +2083,9 @@ function _At_or_Ac_mul_B!(tfun::Function,
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == n || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
n == 0 && return y
Comment thread
dkarrasch marked this conversation as resolved.
n == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -2102,7 +2100,6 @@ function _At_or_Ac_mul_B!(tfun::Function,
1, mx, xnzind, xnzval)
@inbounds y[j] += s * α
end
return y
end


Expand Down
Loading
Loading