8484
8585Initialize the linear solver for the given system.
8686"""
87- function init_linear_solver! (SC, A, timer, method_linear, precon_linear )
87+ function init_linear_solver! (SC, A, b, timer )
8888
8989 # TODO use the timer
9090 time_assembly = 0.0
@@ -98,11 +98,14 @@ function init_linear_solver!(SC, A, timer, method_linear, precon_linear)
9898 stats = @timed begin
9999 abstol = SC. parameters[:abstol ]
100100 reltol = SC. parameters[:reltol ]
101- LP = SC. LP
101+ method_linear = SC. parameters[:method_linear ]
102+ precon_linear = SC. parameters[:precon_linear ]
103+
104+ LP = LinearProblem (A, b)
102105 if precon_linear != = nothing
103- SC. linsolver = init (LP, method_linear; Pl = precon_linear (A. entries . cscmatrix ), abstol = abstol, reltol = reltol)
106+ SC. linsolver = init (LP, method_linear; Pl = precon_linear (A), abstol, reltol)
104107 else
105- SC. linsolver = init (LP, method_linear; abstol = abstol, reltol = reltol)
108+ SC. linsolver = init (LP, method_linear; abstol, reltol)
106109 end
107110 end
108111 time_assembly += stats. time
@@ -242,11 +245,10 @@ Solves the linear system and updates the solution vector. This includes:
242245- Computing the residual
243246- Updating the solution with optional damping
244247"""
245- function solve_linear_system! (A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs... )
248+ function solve_linear_system! (A, b, sol, soltemp, residual, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer, kwargs... )
246249
247250 @timeit timer " linear solver" begin
248251
249-
250252 # does the linsolve object need a (new) matrix?
251253 linsolve_needs_matrix = ! SC. parameters[:constant_matrix ] || ! SC. parameters[:initialized ]
252254
@@ -272,9 +274,9 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
272274
273275 if length (PD. restrictions) == 0
274276 if linsolve_needs_matrix
275- linsolve . A = A_unrestricted
277+ linsolve_A = A_unrestricted
276278 end
277- linsolve . b = b_unrestricted
279+ linsolve_b = b_unrestricted
278280 else
279281 # add possible Lagrange restrictions
280282 restriction_matrices = [length (freedofs) > 0 ? view (restriction_matrix (re), freedofs, :) : restriction_matrix (re) for re in PD. restrictions ]
@@ -342,11 +344,7 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
342344 end
343345
344346 b_block = BlockVector (zeros (Tv, total_size), block_sizes)
345-
346347 b_block[Block (1 )] = b_unrestricted
347- u_unrestricted = @views linsolve. u[1 : block_sizes[1 ]]
348- u_block = BlockVector (zeros (Tv, total_size), block_sizes)
349- u_block[Block (1 )] = u_unrestricted
350348
351349 for i in eachindex (restriction_matrices)
352350 if linsolve_needs_matrix
@@ -358,8 +356,7 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
358356 end
359357
360358 # convert to dense vectors
361- linsolve. b = Vector (b_block)
362- linsolve. u = Vector (u_block)
359+ linsolve_b = Vector (b_block)
363360
364361 if linsolve_needs_matrix
365362
@@ -379,12 +376,28 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
379376 # write each block directly in the resulting matrix
380377 A_flat[range_row, range_col] = A_block[Block (i, j)]
381378 end
382- linsolve . A = A_flat
379+ linsolve_A = A_flat
383380 end
384381 end
385382 end
386383
387- SC. parameters[:initialized ] = true
384+
385+ if SC. parameters[:initialized ]
386+ # set/update the linear system
387+ SC. linsolver. b = linsolve_b
388+ if linsolve_needs_matrix
389+ SC. linsolver. A = linsolve_A
390+ end
391+ else
392+ # init solver if not done before
393+ @timeit timer " linear solver" begin
394+ init_linear_solver! (SC, linsolve_A, linsolve_b, timer)
395+ end
396+ SC. parameters[:initialized ] = true
397+ end
398+
399+ # now the linear solver is definitely ready
400+ linsolve = SC. linsolver
388401
389402 # Solve linear system
390403 push! (stats[:matrix_nnz ], nnz (linsolve. A))
@@ -725,11 +738,6 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
725738 @info " .... isposdef = $(isposdef (A. entries. cscmatrix)) "
726739 end
727740
728- # # init solver
729- @timeit timer " linear solver" begin
730- init_linear_solver! (SC, A, timer, method_linear, precon_linear)
731- end
732-
733741 # # compute nonlinear residual
734742 @timeit timer " assembly" @timeit timer " residual vector" begin
735743 nlres = compute_nonlinear_residual! (residual, A, b, sol, unknowns, PD, SC, freedofs)
@@ -744,11 +752,8 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
744752 break
745753 end
746754
747- linsolve = SC. linsolver
748-
749-
750755 linres = solve_linear_system! (
751- A, b, sol, soltemp, residual, linsolve, unknowns,
756+ A, b, sol, soltemp, residual, unknowns,
752757 freedofs, damping, PD, SC, stats, is_linear, timer
753758 )
754759
@@ -874,11 +879,11 @@ function iterate_until_stationarity(
874879 # Initialize linear solver if needed
875880 solve_init_time, solve_init_allocs = init_linear_solver! (
876881 SC,
877- A,
878- TimerOutput (),
879- SC. parameters[:method_linear ],
880- SC. parameters[:precon_linear ]
882+ A. entries. cscmatrix,
883+ b. entries,
884+ TimerOutput ()
881885 )
886+
882887 time_solve_init += solve_init_time
883888 allocs_solve_init += solve_init_allocs
884889
0 commit comments