Skip to content

Commit 28cdfc5

Browse files
committed
Rework initialize_linear_solver
This removes the `LP` property of the `SolverConfiguration` since all information is stored in the `linsolver`. The initializer is now triggered much later in the `solve`-call: after system matrix and right hand side are known. This is important when restrictions are in use. Moreover, we do not need to assemble a `u` vector for the linear solver. This is done in the init-call.
1 parent 74497d7 commit 28cdfc5

2 files changed

Lines changed: 35 additions & 37 deletions

File tree

src/solver_config.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ mutable struct SolverConfiguration{AT <: AbstractMatrix, bT, xT}
1414
tempsol::xT ## temporary solution
1515
res::xT
1616
freedofs::Vector{Int} ## stores indices of free dofs
17-
LP::LinearProblem
1817
statistics::Dict{Symbol, Any}
1918
linsolver::Any
2019
unknown_ids_in_sol::Array{Int, 1}
@@ -221,11 +220,5 @@ function SolverConfiguration(Problem::ProblemDescription, unknowns::Array{Unknow
221220
x_temp = x
222221
end
223222

224-
## construct linear problem
225-
if length(freedofs) > 0
226-
LP = LinearProblem(A.entries.cscmatrix[freedofs, freedofs], b.entries[freedofs])
227-
else
228-
LP = LinearProblem(A.entries.cscmatrix, b.entries)
229-
end
230-
return SolverConfiguration{typeof(A), typeof(b), typeof(x)}(Problem, A, b, x, x_temp, res, freedofs, LP, default_statistics(TvM, TiM), nothing, unknown_ids_in_sol, unknowns, copy(unknowns), offsets, parameters)
223+
return SolverConfiguration{typeof(A), typeof(b), typeof(x)}(Problem, A, b, x, x_temp, res, freedofs, default_statistics(TvM, TiM), nothing, unknown_ids_in_sol, unknowns, copy(unknowns), offsets, parameters)
231224
end

src/solvers.jl

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ end
8484
8585
Initialize the linear solver for the given system.
8686
"""
87-
function init_linear_solver!(SC, A, timer, method_linear, precon_linear)
87+
function init_linear_solver!(SC, A, b, timer)
8888

8989
# TODO use the timer
9090
time_assembly = 0.0
@@ -98,11 +98,14 @@ function init_linear_solver!(SC, A, timer, method_linear, precon_linear)
9898
stats = @timed begin
9999
abstol = SC.parameters[:abstol]
100100
reltol = SC.parameters[:reltol]
101-
LP = SC.LP
101+
method_linear = SC.parameters[:method_linear]
102+
precon_linear = SC.parameters[:precon_linear]
103+
104+
LP = LinearProblem(A, b)
102105
if precon_linear !== nothing
103-
SC.linsolver = init(LP, method_linear; Pl = precon_linear(A.entries.cscmatrix), abstol = abstol, reltol = reltol)
106+
SC.linsolver = init(LP, method_linear; Pl = precon_linear(A), abstol, reltol)
104107
else
105-
SC.linsolver = init(LP, method_linear; abstol = abstol, reltol = reltol)
108+
SC.linsolver = init(LP, method_linear; abstol, reltol)
106109
end
107110
end
108111
time_assembly += stats.time
@@ -242,11 +245,10 @@ Solves the linear system and updates the solution vector. This includes:
242245
- Computing the residual
243246
- Updating the solution with optional damping
244247
"""
245-
function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs...)
248+
function solve_linear_system!(A, b, sol, soltemp, residual, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs...)
246249

247250
@timeit timer "linear solver" begin
248251

249-
250252
# does the linsolve object need a (new) matrix?
251253
linsolve_needs_matrix = !SC.parameters[:constant_matrix] || !SC.parameters[:initialized]
252254

@@ -272,9 +274,9 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
272274

273275
if length(PD.restrictions) == 0
274276
if linsolve_needs_matrix
275-
linsolve.A = A_unrestricted
277+
linsolve_A = A_unrestricted
276278
end
277-
linsolve.b = b_unrestricted
279+
linsolve_b = b_unrestricted
278280
else
279281
# add possible Lagrange restrictions
280282
restriction_matrices = [length(freedofs) > 0 ? view(restriction_matrix(re), freedofs, :) : restriction_matrix(re) for re in PD.restrictions ]
@@ -342,11 +344,7 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
342344
end
343345

344346
b_block = BlockVector(zeros(Tv, total_size), block_sizes)
345-
346347
b_block[Block(1)] = b_unrestricted
347-
u_unrestricted = @views linsolve.u[1:block_sizes[1]]
348-
u_block = BlockVector(zeros(Tv, total_size), block_sizes)
349-
u_block[Block(1)] = u_unrestricted
350348

351349
for i in eachindex(restriction_matrices)
352350
if linsolve_needs_matrix
@@ -358,8 +356,7 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
358356
end
359357

360358
# convert to dense vectors
361-
linsolve.b = Vector(b_block)
362-
linsolve.u = Vector(u_block)
359+
linsolve_b = Vector(b_block)
363360

364361
if linsolve_needs_matrix
365362

@@ -379,12 +376,28 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
379376
# write each block directly in the resulting matrix
380377
A_flat[range_row, range_col] = A_block[Block(i, j)]
381378
end
382-
linsolve.A = A_flat
379+
linsolve_A = A_flat
383380
end
384381
end
385382
end
386383

387-
SC.parameters[:initialized] = true
384+
385+
if SC.parameters[:initialized]
386+
# set/update the linear system
387+
SC.linsolver.b = linsolve_b
388+
if linsolve_needs_matrix
389+
SC.linsolver.A = linsolve_A
390+
end
391+
else
392+
# init solver if not done before
393+
@timeit timer "linear solver" begin
394+
init_linear_solver!(SC, linsolve_A, linsolve_b, timer)
395+
end
396+
SC.parameters[:initialized] = true
397+
end
398+
399+
# now the linear solver is definitely ready
400+
linsolve = SC.linsolver
388401

389402
# Solve linear system
390403
push!(stats[:matrix_nnz], nnz(linsolve.A))
@@ -725,11 +738,6 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
725738
@info ".... isposdef = $(isposdef(A.entries.cscmatrix))"
726739
end
727740

728-
## init solver
729-
@timeit timer "linear solver" begin
730-
init_linear_solver!(SC, A, timer, method_linear, precon_linear)
731-
end
732-
733741
## compute nonlinear residual
734742
@timeit timer "assembly" @timeit timer "residual vector" begin
735743
nlres = compute_nonlinear_residual!(residual, A, b, sol, unknowns, PD, SC, freedofs)
@@ -744,11 +752,8 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
744752
break
745753
end
746754

747-
linsolve = SC.linsolver
748-
749-
750755
linres = solve_linear_system!(
751-
A, b, sol, soltemp, residual, linsolve, unknowns,
756+
A, b, sol, soltemp, residual, unknowns,
752757
freedofs, damping, PD, SC, stats, is_linear, timer
753758
)
754759

@@ -874,11 +879,11 @@ function iterate_until_stationarity(
874879
# Initialize linear solver if needed
875880
solve_init_time, solve_init_allocs = init_linear_solver!(
876881
SC,
877-
A,
878-
TimerOutput(),
879-
SC.parameters[:method_linear],
880-
SC.parameters[:precon_linear]
882+
A.entries.cscmatrix,
883+
b.entries,
884+
TimerOutput()
881885
)
886+
882887
time_solve_init += solve_init_time
883888
allocs_solve_init += solve_init_allocs
884889

0 commit comments

Comments
 (0)