Skip to content

Commit 928549d

Browse files
authored
Merge branch 'main' into dk/cleanup-matprod_dest
2 parents 47667ec + 0c2b44c commit 928549d

11 files changed

Lines changed: 247 additions & 94 deletions

File tree

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
run: which julia
1616
continue-on-error: true
1717
- name: Install Julia, but only if it is not already available in the PATH
18-
uses: julia-actions/setup-julia@v2
18+
uses: julia-actions/setup-julia@v3
1919
with:
2020
version: '1'
2121
arch: ${{ runner.arch }}

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
julia-version: 'nightly'
3636
steps:
3737
- uses: actions/checkout@v6
38-
- uses: julia-actions/setup-julia@v2
38+
- uses: julia-actions/setup-julia@v3
3939
with:
4040
version: ${{ matrix.julia-version }}
4141
arch: ${{ matrix.julia-arch }}
@@ -63,7 +63,7 @@ jobs:
6363
- x64
6464
steps:
6565
- uses: actions/checkout@v6
66-
- uses: julia-actions/setup-julia@v2
66+
- uses: julia-actions/setup-julia@v3
6767
with:
6868
version: ${{ matrix.julia-version }}
6969
arch: ${{ matrix.julia-arch }}
@@ -76,7 +76,7 @@ jobs:
7676
runs-on: ubuntu-latest
7777
steps:
7878
- uses: actions/checkout@v6
79-
- uses: julia-actions/setup-julia@v2
79+
- uses: julia-actions/setup-julia@v3
8080
with:
8181
version: 'nightly'
8282
- name: Generate docs

src/SparseArrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import Base: Matrix, Vector
1919
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
2020
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded, isdiag,
2121
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
22-
matop_dest, generic_matvecmul!, generic_matmatmul!, generic_matmatmul_wrapper!, copytrito!
22+
matop_dest, generic_matvecmul!, generic_matmatmul!, generic_matmatmul_wrapper!, copytrito!, nonzeroinds
2323

2424
import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
2525
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,
@@ -87,7 +87,7 @@ if Base.USE_GPL_LIBS
8787
include("solvers/spqr.jl")
8888
end
8989

90-
zero(a::AbstractSparseArray) = spzeros(eltype(a), size(a)...)
90+
zero(a::AbstractSparseArray{Tv,Ti}) where {Tv,Ti} = spzeros(Tv, Ti, size(a)...)
9191

9292
LinearAlgebra.diagzero(D::Diagonal{<:AbstractSparseMatrix{T}},i,j) where {T} =
9393
spzeros(T, size(D.diag[i], 1), size(D.diag[j], 2))

src/linalg.jl

Lines changed: 155 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
44
RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55
using Random: rand!
66

7+
_fix_size(M, nrow, ncol) = M
8+
9+
# An immutable fixed size wrapper for matrices to work around
10+
# the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
11+
# This is more-of-less a stripped down version of FixedSizeArrays
12+
# which we can't easily use without pulling that into the standard library.
13+
struct _FixedSizeMatrix{Trans,R}
14+
ref::R
15+
nrow::Int
16+
ncol::Int
17+
function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R}
18+
new{Trans,R}(ref, nrow, ncol)
19+
end
20+
end
21+
@inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) =
22+
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[]
23+
@inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) =
24+
@inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] = v
25+
26+
@inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) =
27+
@inbounds transpose(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
28+
@inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) =
29+
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v)
30+
31+
@inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) =
32+
@inbounds adjoint(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[])
33+
@inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) =
34+
@inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v)
35+
36+
@inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol)
37+
@inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) =
38+
_FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol)
39+
@inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) =
40+
_FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol)
41+
742
const tilebufsize = 10800 # Approximately 32k/3
843

944
# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -64,52 +99,99 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
6499
T = eltype(C)
65100
_mul!(rangefun, diagop, odiagop, C, A, wrap(B, tB), T(alpha), T(beta))
66101
else
67-
@stable_muladdmul LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta))
102+
LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
68103
end
69104
return C
70105
end
71106

107+
# Slow non-inlined functions for throwing the error without messing up the caller
108+
@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt)
109+
if At == 'N'
110+
Anames = "first", "second"
111+
else
112+
Anames = "second", "first"
113+
end
114+
if Bt == 'N'
115+
Bnames = "first", "second"
116+
else
117+
Bnames = "second", "first"
118+
end
119+
nA == mB ||
120+
throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB"))
121+
mA == mC ||
122+
throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC"))
123+
nB == nC ||
124+
throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC"))
125+
# unreachable
126+
throw(DimensionMismatch("Unknown dimension mismatch"))
127+
end
128+
129+
@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt}
130+
mC = size(C, 1)
131+
nC = size(C, 2)
132+
mA = size(A, 1)
133+
nA = size(A, 2)
134+
mB = size(B, 1)
135+
nB = size(B, 2)
136+
137+
_mA, _nA = At == 'N' ? (mA, nA) : (nA, mA)
138+
_mB, _nB = Bt == 'N' ? (mB, nB) : (nB, mB)
139+
140+
if (_nA != _mB) | (_mA != mC) | (_nB != nC)
141+
_matmul_size_error(mC, nC, _mA, _nA, _mB, _nB, At, Bt)
142+
end
143+
return mC, nC, mA, nA, mB, nB
144+
end
145+
146+
@inline _matmul_size_AB(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('N'))
147+
@inline _matmul_size_AtB(C, A, B) = _matmul_size(C, A, B, Val('T'), Val('N'))
148+
@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T'))
149+
72150
function _spmatmul!(C, A, B, α, β)
73-
size(A, 2) == size(B, 1) ||
74-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))"))
75-
size(A, 1) == size(C, 1) ||
76-
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))"))
77-
size(B, 2) == size(C, 2) ||
78-
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
151+
Cax2 = axes(C, 2)
152+
Aax2 = axes(A, 2)
153+
mC, nC, mA, nA, mB, nB = _matmul_size_AB(C, A, B)
79154
nzv = nonzeros(A)
80155
rv = rowvals(A)
81-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
82-
for k in axes(C, 2)
83-
@inbounds for col in axes(A,2)
84-
αxj = B[col,k] * α
156+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
157+
if α isa Bool && !α
158+
return
159+
end
160+
B = _fix_size(B, mB, nB)
161+
C = _fix_size(C, mC, nC)
162+
for k in Cax2
163+
@inbounds for col in Aax2
164+
αxj = α isa Bool ? B[col,k] : B[col,k] * α
85165
for j in nzrange(A, col)
86-
C[rv[j], k] += nzv[j]*αxj
166+
rvj = rv[j]
167+
C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k])
87168
end
88169
end
89170
end
90-
C
91171
end
92172

93173
function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
94-
size(A, 2) == size(C, 1) ||
95-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))"))
96-
size(A, 1) == size(B, 1) ||
97-
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))"))
98-
size(B, 2) == size(C, 2) ||
99-
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
174+
Cax2 = axes(C, 2)
175+
Aax2 = axes(A, 2)
176+
mC, nC, mA, nA, mB, nB = _matmul_size_AtB(C, A, B)
100177
nzv = nonzeros(A)
101178
rv = rowvals(A)
102-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
103-
for k in axes(C, 2)
104-
@inbounds for col in axes(A,2)
105-
tmp = zero(eltype(C))
179+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
180+
if α isa Bool && !α
181+
return
182+
end
183+
C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc
184+
B = _fix_size(B, mB, nB)
185+
C = _fix_size(C, mC, nC)
186+
for k in Cax2
187+
@inbounds for col in Aax2
188+
tmp = C0
106189
for j in nzrange(A, col)
107-
tmp += tfun(nzv[j])*B[rv[j],k]
190+
tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp)
108191
end
109-
C[col,k] += tmp * α
192+
C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k])
110193
end
111194
end
112-
C
113195
end
114196

115197
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number, ::LinearAlgebra.BlasFlag.SyrkHerkGemm)
@@ -127,63 +209,71 @@ Base.@constprop :aggressive generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB,
127209
LinearAlgebra._generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
128210

129211
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
130-
mX, nX = size(X)
131-
nX == size(A, 1) ||
132-
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
133-
mX == size(C, 1) ||
134-
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
135-
size(A, 2) == size(C, 2) ||
136-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
212+
Aax2 = axes(A, 2)
213+
Xax1 = axes(X, 1)
214+
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
137215
rv = rowvals(A)
138216
nzv = nonzeros(A)
139-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
140-
@inbounds for col in axes(A,2), k in nzrange(A, col)
141-
Aiα = nzv[k] * α
217+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
218+
if α isa Bool && !α
219+
return
220+
end
221+
C = _fix_size(C, mC, nC)
222+
X = _fix_size(X, mX, nX)
223+
@inbounds for col in Aax2, k in nzrange(A, col)
224+
Aiα = α isa Bool ? nzv[k] : nzv[k] * α
142225
rvk = rv[k]
143-
@simd for multivec_row in axes(X,1)
144-
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
226+
@simd for multivec_row in Xax1
227+
C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα,
228+
C[multivec_row, col])
145229
end
146230
end
147-
C
148231
end
149232
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
150-
mX, nX = size(X)
151-
nX == size(A, 1) ||
152-
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
153-
mX == size(C, 1) ||
154-
throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))"))
155-
size(A, 2) == size(C, 2) ||
156-
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))"))
233+
Xax1 = axes(X, 1)
234+
Cax2 = axes(C, 2)
235+
mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A)
157236
rv = rowvals(A)
158237
nzv = nonzeros(A)
159-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
160-
for multivec_row in axes(X,1), col in axes(C, 2)
161-
@inbounds for k in nzrange(A, col)
162-
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
238+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
239+
if α isa Bool && !α
240+
return
241+
end
242+
C = _fix_size(C, mC, nC)
243+
X = _fix_size(X, mX, nX)
244+
@inbounds for multivec_row in Xax1, col in Cax2
245+
nzrng = nzrange(A, col)
246+
if isempty(nzrng)
247+
continue
248+
end
249+
tmp = C[multivec_row, col]
250+
for k in nzrng
251+
tmp = muladd(X[multivec_row, rv[k]],
252+
isa Bool ? nzv[k] : nzv[k] * α), tmp)
163253
end
254+
C[multivec_row, col] = tmp
164255
end
165-
C
166256
end
167257

168258
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
169-
mA, nA = size(A)
170-
nA == size(B, 2) ||
171-
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
172-
mA == size(C, 1) ||
173-
throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))"))
174-
size(B, 1) == size(C, 2) ||
175-
throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))
259+
Bax2 = axes(B, 2)
260+
Aax1 = axes(A, 1)
261+
mC, nC, mA, nA, mB, nB = _matmul_size_ABt(C, A, B)
176262
rv = rowvals(B)
177263
nzv = nonzeros(B)
178-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
179-
@inbounds for col in axes(B, 2), k in nzrange(B, col)
180-
Biα = tfun(nzv[k]) * α
264+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
265+
if α isa Bool && !α
266+
return
267+
end
268+
C = _fix_size(C, mC, nC)
269+
A = _fix_size(A, mA, nA)
270+
@inbounds for col in Bax2, k in nzrange(B, col)
271+
Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
181272
rvk = rv[k]
182-
@simd for multivec_col in axes(A,1)
183-
C[multivec_col, rvk] += A[multivec_col, col] * Biα
273+
@simd for multivec_col in Aax1
274+
C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk])
184275
end
185276
end
186-
C
187277
end
188278

189279
function *(A::Diagonal, b::AbstractSparseVector)
@@ -1238,7 +1328,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12381328
rv = rowvals(A)
12391329
nzv = nonzeros(A)
12401330
let z = T(0), sumcol=z, αxj=z, aarc=z, α = α
1241-
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
1331+
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
12421332
@inbounds for k in axes(B,2)
12431333
for col in axes(B,1)
12441334
αxj = B[col,k] * α
@@ -1257,7 +1347,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12571347
end
12581348
end
12591349
end
1260-
C
12611350
end
12621351

12631352
# row range up to (and including if excl=false) diagonal

src/solvers/cholmod.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ get_perm(FC::FactorComponent) = get_perm(Factor(FC))
819819
# Conversion/construction
820820

821821
function Dense{T}(A::StridedVecOrMatInclAdjAndTrans) where T<:VTypes
822-
d = allocate_dense(size(A, 1), size(A, 2), A isa StridedVecOrMat ? stride(A, 2) : size(A, 1), T)
822+
d = allocate_dense(size(A, 1), size(A, 2), size(A, 1), T)
823823
D = unsafe_wrap(Array, Ptr{eltype(d)}(unsafe_load(pointer(d)).x), size(A), own = false)
824824
copyto!(D, A)
825825
return d
@@ -990,7 +990,7 @@ function Sparse(A::SparseMatrixCSC{<:Union{ComplexF16, ComplexF32}}, stype::Inte
990990
end
991991

992992
# convert SparseVectors into CHOLMOD Sparse types through a mx1 CSC matrix
993-
Sparse(A::SparseVector) = Sparse(SparseMatrixCSC(A))
993+
Sparse(A::SparseVector) = Sparse(SparseMatrixCSC(A), 0)
994994
function Sparse{Tv, Ti}(A::SparseMatrixCSC) where {Tv<:VTypes, Ti<:ITypes}
995995
o = Sparse{Tv, Ti}(A, 0)
996996
# check if array is symmetric and change stype if it is
@@ -1135,6 +1135,7 @@ function SparseVector{Tv, Ti}(A::Sparse{Tv, Ti}) where {Tv, Ti<:ITypes}
11351135
end
11361136
args = _extract_args(s, Tv)
11371137
s.sorted == 0 && _sort_buffers!(args...);
1138+
_trim_nz_builder!(args...)
11381139
return SparseVector(args[1], args[4], args[5])
11391140
end
11401141

src/solvers/umfpack.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,8 @@ function lu!(F::UmfpackLU{Tv, Ti}, S::AbstractSparseMatrixCSC;
471471
return lu!(F; reuse_symbolic, check, q)
472472
end
473473

474-
function lu!(F::UmfpackLU; check::Bool=true, reuse_symbolic::Bool=true, q=nothing)
474+
function lu!(F::UmfpackLU{Tv, Ti}; check::Bool=true, reuse_symbolic::Bool=true,
475+
q=nothing) where {Tv, Ti}
475476
if !reuse_symbolic && _isnotnull(F.symbolic)
476477
F.symbolic = Symbolic{Tv, Ti}(C_NULL)
477478
end

0 commit comments

Comments
 (0)