Skip to content

Commit 95b6ac4

Browse files
authored
Implement ldiv!() for QRSparse (#676)
I originally was going to implement a separate workspace like how FastLapackInterface.jl does it, but seeing as there's already a `QRSparse` type that kinda functions as a workspace already I decided it made more sense to re-use that. Reusing `QRSparse` is also consistent with the way that `UmfpackLU` holds a workspace. Unlike `UmfpackLU` I did not add a lock because for the example in the `workspace reuse` test it added a ~15% overhead. Instead the docstring clearly warns that using `QRSparse` with `ldiv!()` is not threadsafe and tells users to make an explicit copy. Preallocating `W` removed most of the allocations, and taking a view of `F.R` removed the other big one. Benchmarks: ```julia-repl using SparseArrays, LinearAlgebra m, n = 100, 10 nn = 100 A = sparse([1:n; rand(1:m, nn - n)], [1:n; rand(1:n, nn - n)], randn(nn), m, n) F = qr(A) b = randn(m) x = zeros(n); # Before julia> @benchmark $F \ $b BenchmarkTools.Trial: 10000 samples with 17 evaluations per sample. Range (min … max): 995.294 ns … 203.601 μs ┊ GC (min … max): 0.00% … 98.64% Time (median): 1.245 μs ┊ GC (median): 0.00% Time (mean ± σ): 1.324 μs ± 3.433 μs ┊ GC (mean ± σ): 6.15% ± 2.59% ▁▅▅▃▁ ▁▂▃▃▂ ▄▄▆█▇▆▅▄▂▁ ▁▁▁▁▂▁ ▂ ██████▇██████████████████▇███████▇▆▃▅▄▄▄▂▂▃▄▅▂▅▅▇▇▆▇██▇▇███▇▆ █ 995 ns Histogram: log(frequency) by time 1.81 μs < Memory estimate: 2.14 KiB, allocs estimate: 14. # After julia> @benchmark ldiv!($x, $F, $b) BenchmarkTools.Trial: 10000 samples with 15 evaluations per sample. Range (min … max): 957.467 ns … 3.822 μs ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.154 μs ┊ GC (median): 0.00% Time (mean ± σ): 1.148 μs ± 128.500 ns ┊ GC (mean ± σ): 0.00% ± 0.00% ▃ █ ▁ ▅ ▄█▂▂▁▁▁▁▁▁▁▁▁▁▂█▅▃█▃▂█▆▂▆▄▁▂▂▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 957 ns Histogram: frequency by time 1.61 μs < Memory estimate: 96 bytes, allocs estimate: 2. ``` Fixes #242. Written with help from Claude 🤖
1 parent 4500d86 commit 95b6ac4

2 files changed

Lines changed: 161 additions & 61 deletions

File tree

src/solvers/spqr.jl

Lines changed: 121 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,30 @@ function _qr!(ordering::Integer, tol::Real, econ::Integer, getCTX::Integer,
108108
return rnk, _E, _HPinv
109109
end
110110

111+
struct QRSparseQ{Tv<:CHOLMOD.VTypes,Ti<:Integer} <: AbstractQ{Tv}
112+
factors::SparseMatrixCSC{Tv,Ti}
113+
τ::Vector{Tv}
114+
n::Int # Number of columns in original matrix
115+
end
116+
117+
Base.size(Q::QRSparseQ) = (size(Q.factors, 1), size(Q.factors, 1))
118+
Base.axes(Q::QRSparseQ) = map(Base.OneTo, size(Q))
119+
120+
Matrix{T}(Q::QRSparseQ) where {T} = lmul!(Q, Matrix{T}(I, size(Q, 1), min(size(Q, 1), Q.n)))
121+
111122
# Struct for storing sparse QR from SPQR such that
112123
# A[invperm(rpivinv), cpiv] = (I - factors[:,1]*τ[1]*factors[:,1]')*...*(I - factors[:,k]*τ[k]*factors[:,k]')*R
113124
# with k = size(factors, 2).
114125
struct QRSparse{Tv,Ti} <: LinearAlgebra.Factorization{Tv}
115126
factors::SparseMatrixCSC{Tv,Ti}
116127
τ::Vector{Tv}
117128
R::SparseMatrixCSC{Tv,Ti}
129+
Q::QRSparseQ{Tv,Ti}
118130
cpiv::Vector{Ti}
119131
rpivinv::Vector{Ti}
132+
133+
_lock::ReentrantLock
134+
_ldiv_workspace::Vector{Tv} # backing storage for work buffer (resizable)
120135
end
121136

122137
Base.size(F::QRSparse) = (size(F.factors, 1), size(F.R, 2))
@@ -133,17 +148,6 @@ function Base.size(F::QRSparse, i::Integer)
133148
end
134149
Base.axes(F::QRSparse) = map(Base.OneTo, size(F))
135150

136-
struct QRSparseQ{Tv<:CHOLMOD.VTypes,Ti<:Integer} <: AbstractQ{Tv}
137-
factors::SparseMatrixCSC{Tv,Ti}
138-
τ::Vector{Tv}
139-
n::Int # Number of columns in original matrix
140-
end
141-
142-
Base.size(Q::QRSparseQ) = (size(Q.factors, 1), size(Q.factors, 1))
143-
Base.axes(Q::QRSparseQ) = map(Base.OneTo, size(Q))
144-
145-
Matrix{T}(Q::QRSparseQ) where {T} = lmul!(Q, Matrix{T}(I, size(Q, 1), min(size(Q, 1), Q.n)))
146-
147151
# From SPQR manual p. 6
148152
_default_tol(A::AbstractSparseMatrixCSC) =
149153
20*sum(size(A))*eps()*maximum(norm(view(A, :, i)) for i in axes(A, 2))
@@ -155,6 +159,12 @@ Compute the `QR` factorization of a sparse matrix `A`. Fill-reducing row and col
155159
are used such that `F.R = F.Q'*A[F.prow,F.pcol]`. The main application of this type is to
156160
solve least squares or underdetermined problems with [`\\`](@ref). The function calls the C library SPQR[^ACM933].
157161
162+
!!! note
163+
The returned `QRSparse` object uses an internal workspace for
164+
[`ldiv!()`](@ref) calls that is protected by a lock for threadsafety. For
165+
multithreaded use, create a separate copy of this object for each task with
166+
`copy(F)`.
167+
158168
!!! note
159169
`qr(A::SparseMatrixCSC)` uses the SPQR library that is part of [SuiteSparse](https://github.com/DrTimothyAldenDavis/SuiteSparse).
160170
As this library only supports sparse matrices with [`Float64`](@ref) or
@@ -205,14 +215,19 @@ function LinearAlgebra.qr(A::SparseMatrixCSC{Tv, Ti}; tol=_default_tol(A), order
205215
R, E, H, HPinv, HTau)
206216

207217
R_ = SparseMatrixCSC{Tv, Ti}(Sparse(R[]))
208-
return QRSparse(SparseMatrixCSC{Tv, Ti}(Sparse(H[])),
209-
vec(Array{Tv}(CHOLMOD.Dense(HTau[]))),
210-
SparseMatrixCSC{Tv, Ti}(min(size(A)...),
211-
size(R_, 2),
212-
getcolptr(R_),
213-
rowvals(R_),
214-
nonzeros(R_)),
215-
p, hpinv)
218+
factors = SparseMatrixCSC{Tv, Ti}(Sparse(H[]))
219+
τ = vec(Array{Tv}(CHOLMOD.Dense(HTau[])))
220+
R = SparseMatrixCSC{Tv, Ti}(min(size(A)...),
221+
size(R_, 2),
222+
getcolptr(R_),
223+
rowvals(R_),
224+
nonzeros(R_))
225+
226+
return QRSparse(factors, τ, R,
227+
QRSparseQ(factors, τ, size(R, 2)),
228+
p, hpinv,
229+
ReentrantLock(),
230+
Tv[]) # _ldiv_workspace (lazily sized on first solve)
216231
end
217232
LinearAlgebra.qr(A::SparseMatrixCSC{Float16}; tol=_default_tol(A)) =
218233
qr(convert(SparseMatrixCSC{Float32}, A); tol=tol)
@@ -338,9 +353,7 @@ end
338353
(*)(A::SparseMatrixCSC, Q::QRSparseQ) = A * sparse(Q)
339354

340355
@inline function Base.getproperty(F::QRSparse, d::Symbol)
341-
if d === :Q
342-
return QRSparseQ(F.factors, F.τ, size(F, 2))
343-
elseif d === :prow
356+
if d === :prow
344357
return invperm(F.rpivinv)
345358
elseif d === :pcol
346359
return F.cpiv
@@ -354,6 +367,18 @@ function Base.propertynames(F::QRSparse, private::Bool=false)
354367
private ? ((public fieldnames(typeof(F)))...,) : public
355368
end
356369

370+
"""
371+
copy(F::QRSparse)
372+
373+
A shallow copy of QRSparse for use in multithreaded solve applications.
374+
Shares the factorization data but duplicates the workspace so that
375+
each copy can be used independently in a different thread.
376+
"""
377+
function Base.copy(F::QRSparse)
378+
QRSparse(F.factors, F.τ, F.R, F.Q, F.cpiv, F.rpivinv,
379+
ReentrantLock(), similar(F._ldiv_workspace))
380+
end
381+
357382
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::QRSparse)
358383
summary(io, F); println(io)
359384
println(io, "Q factor:")
@@ -406,49 +431,29 @@ function (\)(F::QRSparse{T}, B::VecOrMat{Complex{T}}) where T<:LinearAlgebra.Bla
406431
return collect(reshape(reinterpret(Complex{T}, copy(transpose(reshape(x, (length(x) >> 1), 2)))), _ret_size(F, B)))
407432
end
408433

409-
function _ldiv_basic(F::QRSparse, B::StridedVecOrMat)
410-
if size(F, 1) != size(B, 1)
411-
throw(DimensionMismatch("size(F) = $(size(F)) but size(B) = $(size(B))"))
412-
end
413-
414-
# The rank of F equal might be reduced
415-
rnk = rank(F)
416-
417-
# allocate an array for the return value large enough to hold B and X
418-
# For overdetermined problem, B is larger than X and vice versa
419-
X = similar(B, ntuple(i -> i == 1 ? max(size(F, 2), size(B, 1)) : size(B, 2), Val(ndims(B))))
434+
function _get_ldiv_workspace(F::QRSparse{Tv}, B::StridedVecOrMat) where Tv
435+
m, n = size(F)
436+
k = ndims(B) == 1 ? 1 : size(B, 2)
437+
wrows = max(m, n)
420438

421-
# Fill will zeros. These will eventually become the zeros in the basic solution
422-
# fill!(X, 0)
423-
# Apply left permutation to the solution and store in X
424-
for j in axes(B, 2)
425-
for i in 1:length(F.rpivinv)
426-
@inbounds X[F.rpivinv[i], j] = B[i, j]
427-
end
439+
# Resize backing vector if needed
440+
wlen = wrows * k
441+
if length(F._ldiv_workspace) != wlen
442+
resize!(F._ldiv_workspace, wlen)
428443
end
429444

430-
# Make a view into X corresponding to the size of B
431-
X0 = view(X, axes(B, 1), :)
432-
433-
# Apply Q' to B
434-
lmul!(adjoint(F.Q), X0)
435-
436-
# Zero out to get basic solution
437-
X[rnk + 1:end, :] .= 0
438-
439-
# Solve R*X = B
440-
ldiv!(UpperTriangular(F.R[Base.OneTo(rnk), Base.OneTo(rnk)]),
441-
view(X0, Base.OneTo(rnk), :))
445+
# Reshape into matrix. Note that we use ReshapedArray here instead of
446+
# reshape() to avoid allocations later when taking a view.
447+
W = Base.ReshapedArray(F._ldiv_workspace, (wrows, k), ())
448+
return W
449+
end
442450

443-
# Apply right permutation and extract solution from X
444-
# NB: cpiv == [] if SPQR was called with ORDERING_FIXED
445-
if length(F.cpiv) == 0
446-
return getindex(X, ntuple(i -> i == 1 ? (1:size(F,2)) : :, Val(ndims(B)))...)
447-
end
448-
return getindex(X, ntuple(i -> i == 1 ? invperm(F.cpiv) : :, Val(ndims(B)))...)
451+
function (\)(F::QRSparse{T}, B::StridedVecOrMat{T}) where {T}
452+
X = similar(B, ntuple(i -> i == 1 ? size(F, 2) : size(B, 2), Val(ndims(B))))
453+
# Note that we copy F here for thread-safety
454+
return ldiv!(X, copy(F), B)
449455
end
450456

451-
(\)(F::QRSparse{T}, B::StridedVecOrMat{T}) where {T} = _ldiv_basic(F, B)
452457
"""
453458
(\\)(F::QRSparse, B::StridedVecOrMat)
454459
@@ -473,4 +478,61 @@ julia> qr(A)\\fill(1.0, 4)
473478
"""
474479
(\)(F::QRSparse, B::StridedVecOrMat) = F\convert(AbstractArray{eltype(F)}, B)
475480

481+
function LinearAlgebra.ldiv!(X::StridedVecOrMat{T}, F::QRSparse{T}, B::StridedVecOrMat{T}) where {T}
482+
if size(F, 1) != size(B, 1)
483+
throw(DimensionMismatch("size(F) = $(size(F)) but size(B) = $(size(B))"))
484+
end
485+
if size(F, 2) != size(X, 1)
486+
throw(DimensionMismatch("size(F) = $(size(F)) but size(X) = $(size(X))"))
487+
end
488+
if ndims(B) > 1 && size(X, 2) != size(B, 2)
489+
throw(DimensionMismatch("size(X) = $(size(X)) but size(B) = $(size(B))"))
490+
end
491+
492+
rnk = rank(F)
493+
m = size(F, 1)
494+
n = size(F, 2)
495+
496+
@lock F._lock begin
497+
W = _get_ldiv_workspace(F, B)
498+
499+
# Apply left permutation to B and store in W
500+
for j in axes(B, 2)
501+
for i in 1:length(F.rpivinv)
502+
@inbounds W[F.rpivinv[i], j] = B[i, j]
503+
end
504+
end
505+
506+
# Make a view into W corresponding to the size of B
507+
W0 = @view W[Base.OneTo(m), :]
508+
509+
# Apply Q' to permuted B
510+
lmul!(adjoint(F.Q), W0)
511+
512+
# Solve R*X = Q'*P*B
513+
ldiv!(UpperTriangular(@view(F.R[Base.OneTo(rnk), Base.OneTo(rnk)])),
514+
@view(W0[Base.OneTo(rnk), :]))
515+
516+
# Apply right permutation: scatter solved rows into X using cpiv directly.
517+
# Zero X first so free variables (beyond rank) are zero in the basic solution.
518+
# NB: cpiv == [] if SPQR was called with ORDERING_FIXED
519+
fill!(X, zero(T))
520+
if length(F.cpiv) == 0
521+
for j in axes(W, 2)
522+
for i in 1:rnk
523+
@inbounds X[i, j] = W[i, j]
524+
end
525+
end
526+
else
527+
for j in axes(W, 2)
528+
for i in 1:rnk
529+
@inbounds X[F.cpiv[i], j] = W[i, j]
530+
end
531+
end
532+
end
533+
end
534+
535+
return X
536+
end
537+
476538
end # module

test/spqr.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ else
99

1010
using SparseArrays.SPQR
1111
using SparseArrays.CHOLMOD
12-
using LinearAlgebra: I, istriu, norm, qr, rank, rmul!, lmul!, Adjoint, Transpose, ColumnNorm, RowMaximum, NoPivot
12+
using LinearAlgebra: I, istriu, norm, qr, rank, rmul!, lmul!, ldiv!, Adjoint, Transpose, ColumnNorm, RowMaximum, NoPivot
1313
using SparseArrays: SparseArrays, sparse, sprandn, spzeros, SparseMatrixCSC
1414
using Random: seed!
1515

@@ -150,7 +150,7 @@ end
150150
A = sparse([0.0 1 0 0; 0 0 0 0])
151151
F = qr(A)
152152
@test propertynames(F) == (:R, :Q, :prow, :pcol)
153-
@test propertynames(F, true) == (:R, :Q, :prow, :pcol, :factors, , :cpiv, :rpivinv)
153+
@test propertynames(F, true) == (:R, :Q, :prow, :pcol, :factors, , :cpiv, :rpivinv, :_lock, :_ldiv_workspace)
154154
end
155155

156156
@testset "rank" begin
@@ -180,6 +180,44 @@ end
180180
@test V' * Dq.Q V' * Matrix(Dq.Q)
181181
end
182182

183+
@testset "ldiv!" begin
184+
@testset "workspace reuse" begin
185+
A = sparse([1:n; rand(1:m, nn - n)], [1:n; rand(1:n, nn - n)], randn(nn), m, n)
186+
F = qr(A)
187+
b = randn(m)
188+
x = zeros(n)
189+
190+
# First call will allocate the workspace
191+
first_allocs = @allocated ldiv!(x, F, b)
192+
@test length(F._ldiv_workspace) > 0
193+
@test x Array(A) \ b
194+
195+
# Second call with same-sized RHS should reuse workspace
196+
b2 = randn(m)
197+
second_allocs = @allocated ldiv!(x, F, b2)
198+
@test second_allocs < first_allocs
199+
@test x Array(A) \ b2
200+
end
201+
202+
@testset "dimension errors" begin
203+
A = sprandn(m, n, 0.5)
204+
F = qr(A)
205+
@test_throws DimensionMismatch ldiv!(zeros(n), F, zeros(m - 1))
206+
@test_throws DimensionMismatch ldiv!(zeros(n - 1), F, zeros(m))
207+
@test_throws DimensionMismatch ldiv!(zeros(n, 2), F, zeros(m, 3))
208+
end
209+
210+
@testset "copying QRSparse" begin
211+
A = sprandn(m, n, 0.5)
212+
F = qr(A)
213+
F_copy = copy(F)
214+
215+
# These fields must not be shared
216+
@test F._lock !== F_copy._lock
217+
@test F._ldiv_workspace !== F_copy._ldiv_workspace
218+
end
219+
end
220+
183221
@testset "no strategies" begin
184222
A = I + sprandn(10, 10, 0.1)
185223
for i in [ColumnNorm, RowMaximum, NoPivot]

0 commit comments

Comments
 (0)