@@ -343,8 +343,9 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
343343
344344 end
345345
346- linsolve. b = Vector (b_block) # convert to dense vector
347- linsolve. u = Vector (u_block) # convert to dense vector
346+ # convert to dense vectors
347+ linsolve. b = Vector (b_block)
348+ linsolve. u = Vector (u_block)
348349
349350 if linsolve_needs_matrix
350351
@@ -374,16 +375,53 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
374375 # Solve linear system
375376 push! (stats[:matrix_nnz ], nnz (linsolve. A))
376377 @timeit timer " solve! call" begin
377- blocked_result = LinearSolve. solve! (linsolve)
378- blocked_Δx = blocked_result. u
378+ LinearSolve. solve! (linsolve)
379379 end
380+ end
380381
381- # extract the solution / dismiss the lagrange multipliers
382- @views Δx = blocked_Δx[1 : length (b_unrestricted)]
382+ # Check linear residual
383+ @timeit timer " linear residual computation" begin
384+
385+ # compute flat residual (reuse b_flat): residual_flat = A_flat * u_flat - b_flat
386+ residual_flat = linsolve. b
387+ mul! (residual_flat, linsolve. A, linsolve. u, 1.0 , - 1.0 )
388+
389+ for op in PD. operators
390+ for dof in fixed_dofs (op)
391+ # fix dofs only in first block
392+ if dof <= length (b_unrestricted)
393+ residual_flat[dof] = 0
394+ end
395+ end
396+ end
397+
398+ if length (freedofs) > 0
399+ residual. entries[freedofs] .= @views residual_flat[1 : length (b_unrestricted)]
400+ else
401+ residual. entries .= @views residual_flat[1 : length (b_unrestricted)]
402+ end
403+
404+ if length (PD. restrictions) > 0
405+ # extract all residuals for the restriction blocks
406+ block_ends = cumsum (block_sizes)
407+ restriction_residuals = [norm (residual_flat[(block_ends[i] + 1 ): block_ends[i + 1 ]]) for i in 1 : (length (block_sizes) - 1 ) ]
408+ push! (stats[:restriction_residuals ], restriction_residuals)
409+ end
410+
411+ linres = norm (residual_flat)
412+ push! (stats[:linear_residuals ], linres)
413+ if is_linear
414+ push! (stats[:nonlinear_residuals ], linres)
415+ end
383416 end
384417
418+
385419 # Compute solution update
386420 @timeit timer " update solution" begin
421+
422+ # extract the solution / dismiss the lagrange multipliers
423+ @views Δx = linsolve. u[1 : length (b_unrestricted)]
424+
387425 if length (freedofs) > 0
388426 x = sol. entries[freedofs] + Δx
389427 else
@@ -398,36 +436,6 @@ function solve_linear_system!(A, b, sol, soltemp, residual, linsolve, unknowns,
398436 end
399437 end
400438
401- # Check linear residual
402- @timeit timer " linear residual computation" begin
403- residual. entries .= b. entries
404- if length (freedofs) > 0
405- soltemp. entries[freedofs] .= x
406- residual. entries .- = A. entries. cscmatrix * soltemp. entries
407- else
408- residual. entries .- = A. entries. cscmatrix * x
409- end
410- for op in PD. operators
411- for dof in fixed_dofs (op)
412- if dof <= length (residual. entries)
413- residual. entries[dof] = 0
414- end
415- end
416- end
417- for rs in PD. restrictions
418- for dof in fixed_dofs (rs)
419- if dof <= length (residual. entries)
420- residual. entries[dof] = 0
421- end
422- end
423- end
424- end
425-
426- linres = norm (residual. entries)
427- push! (stats[:linear_residuals ], linres)
428- if is_linear
429- push! (stats[:nonlinear_residuals ], linres)
430- end
431439
432440 # Update solution
433441 @timeit timer " update solution" begin
0 commit comments