diff --git a/src/solvers.jl b/src/solvers.jl index 1fa01dae..6def518e 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -1,14 +1,165 @@ """ -```` -get_unknown_id(SC::SolverConfiguration, u::Unknown) -```` - -returns the id of the unknown u in SC + get_unknown_id(SC::SolverConfiguration, u::Unknown) +Returns the id of the unknown u in SC. """ function get_unknown_id(SC::SolverConfiguration, u::Unknown) return findfirst(==(u), SC.unknowns) end +""" + update_solution!(sol, x, unknowns, freedofs, damping) + +Update the solution vector with the new values, including optional damping. +""" +function update_solution!(sol, x, unknowns, freedofs, damping) + if length(freedofs) > 0 + sol.entries[freedofs] .= x + else + offset = 0 + for u in unknowns + ndofs_u = length(view(sol[u])) + if damping > 0 + view(sol[u]) .= damping * view(sol[u]) + (1 - damping) * view(x, (offset + 1):(offset + ndofs_u)) + else + view(sol[u]) .= view(x, (offset + 1):(offset + ndofs_u)) + end + offset += ndofs_u + end + end + + return nothing +end + +""" + compute_nonlinear_residual!(residual, A, b, sol, unknowns, PD, SC, freedofs) + +Compute the nonlinear residual for the current solution. +""" +function compute_nonlinear_residual!(residual, A, b, sol, unknowns, PD, SC, freedofs) + fill!(residual.entries, 0) + for j in 1:length(b), k in 1:length(b) + addblock_matmul!(residual[j], A[j, k], sol[unknowns[k]]) + end + residual.entries .-= b.entries + + for op in PD.operators + residual.entries[fixed_dofs(op)] .= 0 + end + + for u_off in SC.parameters[:inactive] + j = get_unknown_id(SC, u_off) + if j > 0 + fill!(residual[j], 0) + end + end + + nlres = length(freedofs) > 0 ? norm(residual.entries[freedofs]) : norm(residual.entries) + + if SC.parameters[:verbosity] > 0 && length(residual) > 1 + @info "sub-residuals = $(norms(residual))" + end + + return nlres +end + +""" + init_linear_solver!(SC, A, timer, method_linear, precon_linear) + +Initialize the linear solver for the given system. +""" +function init_linear_solver!(SC, A, timer, method_linear, precon_linear) + + # TODO use the timer + time_assembly = 0.0 + allocs_assembly = 0 + + if SC.linsolver === nothing + if SC.parameters[:verbosity] > 0 + @info ".... initializing linear solver ($(method_linear))\n" + end + @timeit timer "initialization" begin + stats = @timed begin + abstol = SC.parameters[:abstol] + reltol = SC.parameters[:reltol] + LP = SC.LP + if precon_linear !== nothing + SC.linsolver = init(LP, method_linear; Pl = precon_linear(A.entries.cscmatrix), abstol = abstol, reltol = reltol) + else + SC.linsolver = init(LP, method_linear; abstol = abstol, reltol = reltol) + end + end + time_assembly += stats.time + allocs_assembly += stats.bytes + end + end + return time_assembly, allocs_assembly +end + +""" + assemble_system!(A, b, sol, PD, SC, timer; kwargs...) + +Assemble the system matrix and right-hand side for the given problem. +""" +function assemble_system!(A, b, sol, PD, SC, timer; kwargs...) + + # TODO use the timer + time_assembly = 0.0 + allocs_assembly = 0 + + if !SC.parameters[:constant_rhs] + fill!(b.entries, 0) + end + if !SC.parameters[:constant_matrix] + fill!(A.entries.cscmatrix.nzval, 0) + end + + # Assemble operators + if SC.parameters[:initialized] + for op in PD.operators + @timeit timer "$(op.parameters[:name])" begin + stats = @timed assemble!( + A, b, sol, op, SC; + time = SC.parameters[:time], + assemble_matrix = !SC.parameters[:constant_matrix], + assemble_rhs = !SC.parameters[:constant_rhs], + kwargs... + ) + end + time_assembly += stats.time + allocs_assembly += stats.bytes + end + else + for op in PD.operators + @timeit timer "$(op.parameters[:name]) (first)" begin + stats = @timed assemble!( + A, b, sol, op, SC; + time = SC.parameters[:time], + kwargs... + ) + end + time_assembly += stats.time + allocs_assembly += stats.bytes + end + end + flush!(A.entries) + + # Apply penalties + for op in PD.operators + @timeit timer "$(op.parameters[:name]) (penalties)" begin + stats = @timed apply_penalties!( + A, b, sol, op, SC; + assemble_matrix = !SC.parameters[:initialized] || !SC.parameters[:constant_matrix], + assemble_rhs = !SC.parameters[:initialized] || !SC.parameters[:constant_rhs], + kwargs... + ) + end + time_assembly += stats.time + allocs_assembly += stats.bytes + end + flush!(A.entries) + + return time_assembly, allocs_assembly +end """ @@ -61,21 +212,262 @@ function symmetrize_structure!(A::ExtendableSparseMatrix{Tv, Ti}; diagval = 1.0e return flush!(A) end -function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{<:FESpace}}, SC = nothing; unknowns = PD.unknowns, kwargs...) - if typeof(FES) <: FESpace - FES = [FES] +""" + solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer) + +Solves the linear system and updates the solution vector. This includes: +- Setting up the system matrix and right-hand side +- Solving the linear system +- Computing the residual +- Updating the solution with optional damping +""" +function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns, freedofs, damping, PD, SC, stats, is_linear, timer) + # Update system matrix if needed + if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] + if length(freedofs) > 0 + linsolve.A = A.entries.cscmatrix[freedofs, freedofs] + else + linsolve.A = A.entries.cscmatrix + end + end + + # Set right-hand side + if length(freedofs) > 0 + linsolve.b = residual.entries[freedofs] + else + linsolve.b = residual.entries + end + SC.parameters[:initialized] = true + + # Solve linear system + push!(stats[:matrix_nnz], nnz(linsolve.A)) + @timeit timer "solve! call" Δx = LinearSolve.solve!(linsolve) + + # Compute solution update + @timeit timer "update solution" begin + if length(freedofs) > 0 + x = sol.entries[freedofs] - Δx.u + else + x = zero(Δx) + offset = 0 + for u in unknowns + ndofs_u = length(view(sol[u])) + x_range = (offset + 1):(offset + ndofs_u) + x[x_range] .= view(sol[u]) .- view(Δx, x_range) + offset += ndofs_u + end + end + end + + # Check linear residual + @timeit timer "linear residual computation" begin + if length(freedofs) > 0 + soltemp.entries[freedofs] .= x + residual.entries .= A.entries.cscmatrix * soltemp.entries + else + residual.entries .= A.entries.cscmatrix * x + end + residual.entries .-= b.entries + for op in PD.operators + for dof in fixed_dofs(op) + if dof <= length(residual.entries) + residual.entries[dof] = 0 + end + end + end + end + + linres = norm(residual.entries) + push!(stats[:linear_residuals], linres) + if is_linear + push!(stats[:nonlinear_residuals], linres) + end + + # Update solution + @timeit timer "update solution" begin + update_solution!(sol, x, unknowns, freedofs, damping) + end + + return linres +end + +""" + initialize_coupled_solution(FES, init, unknowns, nPDs) + +Initialize the solution vector and finite element spaces for coupled problems. +""" +function initialize_coupled_solution(FES, init, unknowns, nPDs) + if FES === nothing + @assert init !== nothing "need init vector or FES (as a Vector{Vector{<:FESpace}})" + @info ".... taking FESpaces from init vector \n" + all_unknowns = init.tags + for p in 1:nPDs, u in unknowns[p] + @assert u in all_unknowns "did not found unknown $u in init vector (tags missing?)" + end + FES = [[init[u].FES for u in unknowns[j]] for j in 1:nPDs] + sol = copy(init) + sol.tags .= init.tags + else + all_unknowns = [] + for p in 1:nPDs, u in unknowns[p] + if !(u in all_unknowns) + push!(u, all_unknowns) + end + end + sol = FEVector(FES; tags = all_unknowns) + end + return sol, FES +end + +""" + solve_coupled_system!(A, b, sol, residual, linsolve, unknowns, damping, PD, SC) + +Solves the coupled system for a single subproblem and updates the solution. +""" +function solve_coupled_system!(A, b, sol, residual, linsolve, unknowns, damping, PD, SC) + if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] + linsolve.A = A.entries.cscmatrix + end + + # we solve for A Δx = r + # and update x = sol - Δx + linsolve.b = residual.entries + SC.parameters[:initialized] = true + + ## solve + Δx = LinearSolve.solve!(linsolve) + + # x = sol.entries - Δx.u ... in the entry ranges of the present unknowns + x = zero(Δx.u) + offset = 0 + for u in unknowns + ndofs_u = length(view(sol[u])) + x_range = (offset + 1):(offset + ndofs_u) + x[x_range] .= view(sol[u]) .- view(Δx.u, x_range) + offset += ndofs_u + end + + fill!(residual.entries, 0) + mul!(residual.entries, A.entries.cscmatrix, x) + residual.entries .-= b.entries + for op in PD.operators + for dof in fixed_dofs(op) + if dof <= length(residual.entries) + residual.entries[dof] = 0 + end + end + end + + linres = norm(residual.entries) + + offset = 0 + for u in unknowns + ndofs_u = length(view(sol[u])) + if damping > 0 + view(sol[u]) .= damping * view(sol[u]) + (1 - damping) * view(x, (offset + 1):(offset + ndofs_u)) + else + view(sol[u]) .= view(x, (offset + 1):(offset + ndofs_u)) + end + offset += ndofs_u end + + return linres +end + +""" + check_problem_linearity!(PDs, SCs, unknowns, is_linear) + +Check the linearity of each subproblem and set appropriate flags. +""" +function check_problem_linearity!(PDs, SCs, unknowns) + nPDs = length(PDs) + is_linear = zeros(Bool, nPDs) + nonlinear = zeros(Bool, nPDs) + for (j, PD) in enumerate(PDs) + for op in PD.operators + nl_dependencies = depends_nonlinearly_on(op) + for u in unknowns + if u in nl_dependencies + nonlinear[j] = true + break + end + end + end + if SCs[j].parameters[:verbosity] > 0 + @info "nonlinear = $(nonlinear[j] ? "true" : "false")\n" + end + if SCs[j].parameters[:is_linear] == "auto" + is_linear[j] = !nonlinear[j] + end + if is_linear[j] && nonlinear[j] + @warn "problem $(PD.name) seems nonlinear, but user set is_linear = true (results may be wrong)!!" + end + end + return is_linear +end + +""" + init_solver_config(PD::ProblemDescription, FES::Union{<:FESpace, Vector{<:FESpace}}, SC, unknowns, kwargs) + +Initialize and configure the solver based on the problem description and finite element spaces. +""" +function init_solver_config(PD::ProblemDescription, FES::Union{<:FESpace, Vector{<:FESpace}}, SC, unknowns, kwargs) + FES_array = typeof(FES) <: FESpace ? [FES] : FES + if typeof(SC) <: SolverConfiguration _update_params!(SC.parameters, kwargs) - if SC.parameters[:verbosity] > 0 - @info ".... reusing given solver configuration\n" + SC.parameters[:verbosity] > 0 && @info ".... reusing given solver configuration\n" + else + SC = SolverConfiguration(PD, unknowns, FES_array; kwargs...) + SC.parameters[:verbosity] > 0 && @info ".... init solver configuration\n" + end + return SC, FES_array +end + + +function print_convergence_result(SC, is_linear, linres, nlres, nltol, j, maxits) + + stop = false + if nlres < nltol + if SC.parameters[:verbosity] > -1 + @printf " END\t" + @printf "%.3e\t" nlres + @printf "converged\n" end + stop = true + elseif isnan(nlres) + if SC.parameters[:verbosity] > -1 + @printf " END\t" + @printf "%.3e\t" nlres + @printf "not converged\n" + end + stop = true + elseif (j == maxits + 1) && !(is_linear) + if SC.parameters[:verbosity] > -1 + @printf " END\t" + @printf "\t\t%.3e\t" linres + @printf "maxiterations reached\n" + end + stop = true else - SC = SolverConfiguration(PD, unknowns, FES; kwargs...) - if SC.parameters[:verbosity] > 0 - @info ".... init solver configuration\n" + if SC.parameters[:verbosity] > -1 + if is_linear + @printf " END\t" + else + @printf "%4d\t" j + end + if !(is_linear) + @printf "%.3e\t" nlres + else + @printf "---------\t" + end end end + return stop +end + +function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{<:FESpace}}, SC = nothing; unknowns = PD.unknowns, kwargs...) + SC, FES = init_solver_config(PD, FES, SC, unknowns, kwargs) ## load TimerOutputs timer = timeroutputs(SC) @@ -99,7 +491,6 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{ maxits = SC.parameters[:maxiterations] @assert maxits > -1 nltol = SC.parameters[:target_residual] - is_linear = SC.parameters[:is_linear] damping = SC.parameters[:damping] freedofs = SC.freedofs @@ -122,25 +513,7 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{ end ## check if problem is (non)linear - nonlinear = false - for op in PD.operators - nl_dependencies = depends_nonlinearly_on(op) - for u in unknowns - if u in nl_dependencies - nonlinear = true - break - end - end - end - if SC.parameters[:verbosity] > 0 - @info " nonlinear = $(nonlinear ? "true" : "false")\n" - end - if is_linear == "auto" - is_linear = !nonlinear - end - if is_linear && nonlinear - @warn "problem seems nonlinear, but user set is_linear = true (results may be wrong)!!" - end + is_linear = check_problem_linearity!([PD], [SC], unknowns)[1] if is_linear maxits = 0 end @@ -151,66 +524,13 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{ end nlres = 1.1e30 linres = 1.1e30 - linsolve = SC.linsolver - reduced = false for j in 1:(maxits + 1) if is_linear && j == 2 nlres = linres else @timeit timer "assembly" begin - - ## assemble operators - if !SC.parameters[:constant_rhs] - fill!(b.entries, 0) - end - if !SC.parameters[:constant_matrix] - fill!(A.entries.cscmatrix.nzval, 0) - end - if SC.parameters[:initialized] - for op in PD.operators - @timeit timer "$(op.parameters[:name])" assemble!(A, b, sol, op, SC; time = SC.parameters[:time], assemble_matrix = !SC.parameters[:constant_matrix], assemble_rhs = !SC.parameters[:constant_rhs], kwargs...) - end - else - for op in PD.operators - @timeit timer "$(op.parameters[:name]) (first)" assemble!(A, b, sol, op, SC; time = SC.parameters[:time], kwargs...) - end - end - flush!(A.entries) - - ## penalize fixed dofs - for op in PD.operators - @timeit timer "$(op.parameters[:name]) (penalties)" apply_penalties!(A, b, sol, op, SC; assemble_matrix = !SC.parameters[:initialized] || !SC.parameters[:constant_matrix], assemble_rhs = !SC.parameters[:initialized] || !SC.parameters[:constant_rhs], kwargs...) - end - flush!(A.entries) - # end - - # ## remove inactive dofs - # for u_off in SC.parameters[:inactive] - # j = get_unknown_id(SC, u_off) - # if j > 0 - # fill!(A[j,j],0) - # FES = sol[j].FES - # for dof in 1:FES.ndofs - # A[j,j][dof, dof] = 1e60 - # b[j][dof] = 1e60*sol[j][dof] - # end - # else - # @warn "inactive unknown $(u_off) not part of unknowns, skipping this one..." - # end - # end - - ## reduction steps - # time_assembly += @elapsed begin - # if length(PD.reduction_operators) > 0 && j == 1 - # LP_reduced = SC.LP - # reduced = true - # for op in PD.reduction_operators - # allocs_assembly += @allocated LP_reduced, A, b = apply!(LP_reduced, op, SC; kwargs...) - # end - # residual = copy(b) - # end - # end + assemble_system!(A, b, sol, PD, SC, timer; kwargs...) end ## show spy @@ -225,178 +545,36 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{ @info ".... spy plot of system matrix:\n$(A.entries.cscmatrix))" end if SC.parameters[:check_matrix] - #λ, ϕ = Arpack.eigs(A.entries.cscmatrix; nev = 5, which = :SM, ritzvec = false) - #@info ".... 5 :SM eigs = $(λ)" - #λ, ϕ = Arpack.eigs(A.entries.cscmatrix; nev = 5, which = :LM, ritzvec = false) - #@info ".... 5 :LM eigs = $(λ)" @info ".... ||A - A'|| = $(norm(A.entries.cscmatrix - A.entries.cscmatrix', Inf))" @info ".... isposdef = $(isposdef(A.entries.cscmatrix))" end ## init solver @timeit timer "linear solver" begin - if linsolve === nothing - if SC.parameters[:verbosity] > 0 - @info ".... initializing linear solver ($(method_linear))\n" - end - @timeit timer "initialization" begin - abstol = SC.parameters[:abstol] - reltol = SC.parameters[:reltol] - LP = reduced ? LP_reduced : SC.LP - if precon_linear !== nothing - linsolve = init(LP, method_linear; Pl = precon_linear(A.entries.cscmatrix), abstol = abstol, reltol = reltol) - else - linsolve = init(LP, method_linear; abstol = abstol, reltol = reltol) - end - SC.linsolver = linsolve - end - end + init_linear_solver!(SC, A, timer, method_linear, precon_linear) end ## compute nonlinear residual @timeit timer "assembly" @timeit timer "residual vector" begin - fill!(residual.entries, 0) - for j in 1:length(b), k in 1:length(b) - addblock_matmul!(residual[j], A[j, k], sol[unknowns[k]]) - end - residual.entries .-= b.entries - #res = A.entries * sol.entries - b.entries - for op in PD.operators - residual.entries[fixed_dofs(op)] .= 0 - end - for u_off in SC.parameters[:inactive] - j = get_unknown_id(SC, u_off) - if j > 0 - fill!(residual[j], 0) - end - end - if length(freedofs) > 0 - nlres = norm(residual.entries[freedofs]) - else - nlres = norm(residual.entries) - end - if SC.parameters[:verbosity] > 0 && length(residual) > 1 - @info "sub-residuals = $(norms(residual))" - end + nlres = compute_nonlinear_residual!(residual, A, b, sol, unknowns, PD, SC, freedofs) end end if !is_linear push!(stats[:nonlinear_residuals], nlres) end - if nlres < nltol - if SC.parameters[:verbosity] > -1 - @printf " END\t" - @printf "%.3e\t" nlres - @printf "converged\n" - end - break - elseif isnan(nlres) - if SC.parameters[:verbosity] > -1 - @printf " END\t" - @printf "%.3e\t" nlres - @printf "not converged\n" - end - break - elseif (j == maxits + 1) && !(is_linear) - if SC.parameters[:verbosity] > -1 - @printf " END\t" - @printf "\t\t%.3e\t" linres - @printf "maxiterations reached\n" - end + + stop = print_convergence_result(SC, is_linear, linres, nlres, nltol, j, maxits) + if stop break - else - if SC.parameters[:verbosity] > -1 - if is_linear - @printf " END\t" - else - @printf "%4d\t" j - end - if !(is_linear) - @printf "%.3e\t" nlres - else - @printf "---------\t" - end - end end - @timeit timer "linear solver" begin - if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] - if length(freedofs) > 0 - linsolve.A = A.entries.cscmatrix[freedofs, freedofs] - else - linsolve.A = A.entries.cscmatrix - end - end - - # we solve for A Δx = r - # and update x = sol - Δx - if length(freedofs) > 0 - linsolve.b = residual.entries[freedofs] - else - linsolve.b = residual.entries - end - SC.parameters[:initialized] = true - - ## solve - push!(stats[:matrix_nnz], nnz(linsolve.A)) - @timeit timer "solve! call" Δx = LinearSolve.solve!(linsolve) + linsolve = SC.linsolver - # x = sol.entries - Δx.u for free dofs or partial solutions - @timeit timer "update solution" begin - if length(freedofs) > 0 - x = sol.entries[freedofs] - Δx.u - else - x = zero(Δx) - offset = 0 - for u in unknowns - ndofs_u = length(view(sol[u])) - x_range = (offset + 1):(offset + ndofs_u) - x[x_range] .= view(sol[u]) .- view(Δx, x_range) - offset += ndofs_u - end - end - end - - ## check linear residual with full matrix - @timeit timer "linear residual computation" begin - if length(freedofs) > 0 - soltemp.entries[freedofs] .= x - residual.entries .= A.entries.cscmatrix * soltemp.entries - else - residual.entries .= A.entries.cscmatrix * x - end - residual.entries .-= b.entries - for op in PD.operators - for dof in fixed_dofs(op) - if dof <= length(residual.entries) - residual.entries[dof] = 0 - end - end - end - end - linres = norm(residual.entries) - push!(stats[:linear_residuals], linres) - if is_linear - push!(stats[:nonlinear_residuals], linres) - end - - ## update solution (incl. damping etc.) - @timeit timer "update solution" begin - offset = 0 - if length(freedofs) > 0 - sol.entries[freedofs] .= x - else - for u in unknowns - ndofs_u = length(view(sol[u])) - if damping > 0 - view(sol[u]) .= damping * view(sol[u]) + (1 - damping) * view(x, (offset + 1):(offset + ndofs_u)) - else - view(sol[u]) .= view(x, (offset + 1):(offset + ndofs_u)) - end - offset += ndofs_u - end - end - end + @timeit timer "linear solver" begin + linres = solve_linear_system!( + A, b, sol, soltemp, residual, linsolve, unknowns, + freedofs, damping, PD, SC, stats, is_linear, timer + ) end if SC.parameters[:verbosity] > -1 if is_linear @@ -468,26 +646,8 @@ function iterate_until_stationarity( PDs::Array{ProblemDescription, 1} = [SC.PD for SC in SCs] nPDs = length(PDs) - ## find FESpaces and generate solution vector - if FES === nothing - @assert init !== nothing "need init vector or FES (as a Vector{Vector{<:FESpace}})" - @info ".... taking FESpaces from init vector \n" - all_unknowns = init.tags - for p in 1:nPDs, u in unknowns[p] - @assert u in all_unknowns "did not found unknown $u in init vector (tags missing?)" - end - FES = [[init[u].FES for u in unknowns[j]] for j in 1:nPDs] - sol = copy(init) - sol.tags .= init.tags - else - all_unknowns = [] - for p in 1:nPDs, u in unknowns[p] - if !(u in all_unknowns) - push!(u, all_unknowns) - end - end - sol = FEVector(FES; tags = all_unknowns) - end + # Initialize solution vector and FE spaces + sol, FES = initialize_coupled_solution(FES, init, unknowns, nPDs) @info "SOLVING iteratively $([PD.name for PD in PDs]) unknowns = $([[uj.name for uj in u] for u in unknowns])" @@ -498,31 +658,8 @@ function iterate_until_stationarity( bs = [SC.b for SC in SCs] residuals = [SC.res for SC in SCs] - ## unpack solver parameters - is_linear = zeros(Bool, nPDs) - - ## check if problems are (non)linear - nonlinear = zeros(Bool, nPDs) - for (j, PD) in enumerate(PDs) - for op in PD.operators - nl_dependencies = depends_nonlinearly_on(op) - for u in unknowns - if u in nl_dependencies - nonlinear[j] = true - break - end - end - end - if SCs[j].parameters[:verbosity] > 0 - @info "nonlinear = $(nonlinear[j] ? "true" : "false")\n" - end - if SCs[j].parameters[:is_linear] == "auto" - is_linear[j] = !nonlinear[j] - end - if is_linear[j] && nonlinear[j] - @warn "problem $(PD.name) seems nonlinear, but user set is_linear = true (results may be wrong)!!" - end - end + # Check linearity of each subproblem + is_linear = check_problem_linearity!(PDs, SCs, unknowns) maxits = [is_linear[j] ? 1 : maxits[j] for j in 1:nPDs] alloc_factor = 1024^2 @@ -553,64 +690,22 @@ function iterate_until_stationarity( damping = SC.parameters[:damping] for j in 1:1 time_total += @elapsed begin - - ## assemble operators - if !SC.parameters[:constant_rhs] - fill!(b.entries, 0) - end - if !SC.parameters[:constant_matrix] - fill!(A.entries.cscmatrix.nzval, 0) - end - if SC.parameters[:initialized] - time_assembly += @elapsed for op in PD.operators - allocs_assembly += @allocated assemble!(A, b, sol, op, SC; time = SC.parameters[:time], assemble_matrix = !SC.parameters[:constant_matrix], assemble_rhs = !SC.parameters[:constant_rhs], kwargs...) - end - else - time_assembly += @elapsed for op in PD.operators - allocs_assembly += @allocated assemble!(A, b, sol, op, SC; time = SC.parameters[:time], kwargs...) - end - end - flush!(A.entries) - - ## penalize fixed dofs - time_assembly += @elapsed for op in PD.operators - allocs_assembly += @allocated apply_penalties!(A, b, sol, op, SC; kwargs...) - end - flush!(A.entries) - - if SC.parameters[:verbosity] > 0 - @printf " assembly time | allocs = %.2f s | %.2f MiB\n" time allocs / alloc_factor - end - - ## show spy - if SC.parameters[:show_matrix] - @show A - elseif SC.parameters[:spy] - @info ".... spy plot of system matrix:\n$(UnicodePlots.spy(sparse(A.entries.cscmatrix)))" - end - - ## init solver - linsolve = SC.linsolver - if linsolve === nothing - if SC.parameters[:verbosity] > 0 - @info ".... initializing linear solver ($(method_linear))\n" - end - time_solve_init += @elapsed begin - allocs_solve_init += @allocated begin - method_linear = SC.parameters[:method_linear] - precon_linear = SC.parameters[:precon_linear] - abstol = SC.parameters[:abstol] - reltol = SC.parameters[:reltol] - LP = SC.LP - if precon_linear !== nothing - linsolve = LinearSolve.init(LP, method_linear; Pl = precon_linear(linsolve.A), abstol = abstol, reltol = reltol) - else - linsolve = LinearSolve.init(LP, method_linear; abstol = abstol, reltol = reltol) - end - SC.linsolver = linsolve - end - end - end + # Assemble system and update timing/allocation info + @show A b sol PD SC TimerOutput() + assembly_time, assembly_allocs = assemble_system!(A, b, sol, PD, SC, TimerOutput(); kwargs...) + time_assembly += assembly_time + allocs_assembly += assembly_allocs + + # Initialize linear solver if needed + solve_init_time, solve_init_allocs = init_linear_solver!( + SC, + A, + TimerOutput(), + SC.parameters[:method_linear], + SC.parameters[:precon_linear] + ) + time_solve_init += solve_init_time + allocs_solve_init += solve_init_allocs ## compute nonlinear residual fill!(residual.entries, 0) @@ -642,51 +737,7 @@ function iterate_until_stationarity( time_solve = @elapsed begin allocs_solve = @allocated begin - if !SC.parameters[:constant_matrix] || !SC.parameters[:initialized] - linsolve.A = A.entries.cscmatrix - end - - # we solve for A Δx = r - # and update x = sol - Δx - linsolve.b = residual.entries - SC.parameters[:initialized] = true - - - ## solve - Δx = LinearSolve.solve!(linsolve) - - # x = sol.entries - Δx.u ... in the entry ranges of the present unknowns - x = zero(Δx.u) - offset = 0 - for u in unknowns[p] - ndofs_u = length(view(sol[u])) - x_range = (offset + 1):(offset + ndofs_u) - x[x_range] .= view(sol[u]) .- view(Δx.u, x_range) - offset += ndofs_u - end - - fill!(residual.entries, 0) - mul!(residual.entries, A.entries.cscmatrix, x) - residual.entries .-= b.entries - for op in PD.operators - for dof in fixed_dofs(op) - if dof <= length(residual.entries) - residual.entries[dof] = 0 - end - end - end - #@info residual.entries, norms(residual) - linres = norm(residual.entries) - offset = 0 - for u in unknowns[p] - ndofs_u = length(view(sol[u])) - if damping > 0 - view(sol[u]) .= damping * view(sol[u]) + (1 - damping) * view(x, (offset + 1):(offset + ndofs_u)) - else - view(sol[u]) .= view(x, (offset + 1):(offset + ndofs_u)) - end - offset += ndofs_u - end + linres = solve_coupled_system!(A, b, sol, residual, SC.linsolver, unknowns[p], damping, PD, SC) end end time_total += time_solve