Skip to content

Commit 882efc1

Browse files
committed
Parallel periodic coupling
1 parent 27d9dbe commit 882efc1

7 files changed

Lines changed: 120 additions & 54 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ jobs:
5050
${{ runner.os }}-
5151
- uses: julia-actions/julia-buildpkg@v1
5252
- uses: julia-actions/julia-runtest@v1
53-
# env:
54-
# JULIA_NUM_THREADS: 4
53+
env:
54+
JULIA_NUM_THREADS: 4
5555
- uses: julia-actions/julia-processcoverage@v1
5656
- uses: codecov/codecov-action@v5
5757
docs:

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# CHANGES
22

3+
## v1.5.0
4+
5+
### Added
6+
7+
- `compute_periodic_coupling_matrix` can be computed thread parallel
38

49
## v1.4.0
510

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ExtendableFEM"
22
uuid = "a722555e-65e0-4074-a036-ca7ce79a4aed"
3-
version = "1.4"
3+
version = "1.5.0"
44
authors = ["Christian Merdon <merdon@wias-berlin.de>", "Patrick Jaap <patrick.jaap@wias-berlin.de>"]
55

66
[deps]
7+
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
78
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
89
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -18,12 +19,14 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1819
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
22+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2123
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2224
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2325
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2426

2527
[compat]
2628
Aqua = "0.8"
29+
ChunkSplitters = "3.1.2"
2730
CommonSolve = "0.2"
2831
DiffResults = "1"
2932
DocStringExtensions = "0.8,0.9"

examples/Example212_PeriodicBoundary2D.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ function main(;
121121
h = 1.0e-2,
122122
width = 6.0,
123123
height = 1.0,
124+
threads = 1,
124125
kwargs...
125126
)
126127
xgrid = create_grid(; h, width, height)
@@ -153,8 +154,7 @@ function main(;
153154
return nothing
154155
end
155156

156-
coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!)
157-
display(coupling_matrix)
157+
@time coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
158158
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
159159
end
160160

@@ -172,8 +172,10 @@ end
172172

173173
generateplots = ExtendableFEM.default_generateplots(Example212_PeriodicBoundary2D, "example212.png") #hide
174174
function runtests() #hide
175-
sol, plt = main() #hide
175+
sol, _ = main() #hide
176176
@test abs(maximum(view(sol[1])) - 1.3447465095618172) < 1.0e-3 #hide
177+
sol2, _ = main(threads = 4) #hide
178+
@test sol.entries sol2.entries #hide
177179
return nothing #hide
178180
end #hide
179181

examples/Example312_PeriodicBoundary3D.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ function main(;
136136
width = 6.0,
137137
height = 0.2,
138138
depth = 1,
139+
threads = 1,
139140
kwargs...
140141
)
141142

@@ -168,8 +169,7 @@ function main(;
168169
y[1] = width - x[1]
169170
return nothing
170171
end
171-
coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!)
172-
display(coupling_matrix)
172+
@time coupling_matrix = get_periodic_coupling_matrix(FES, reg_left, reg_right, give_opposite!; parallel = threads > 1, threads)
173173
assign_operator!(PD, CombineDofs(u, u, coupling_matrix; kwargs...))
174174
end
175175

@@ -185,8 +185,10 @@ end
185185

186186
generateplots = ExtendableFEM.default_generateplots(Example312_PeriodicBoundary3D, "example312.png") #hide
187187
function runtests() #hide
188-
sol, plt = main() #hide
189-
@test abs(maximum(view(sol[1])) - 1.8004602502175202) < 2.0e-3 #hide
188+
sol, _ = main() #hide
189+
@test abs(maximum(view(sol[1])) - 1.8004602502175202) < 2.0e-3 #hide
190+
sol2, _ = main(threads = 4) #hide
191+
@test sol.entries sol2.entries #hide
190192
return nothing #hide
191193
end #hide
192194

src/ExtendableFEM.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ $(read(joinpath(@__DIR__, "..", "README.md"), String))
55
"""
66
module ExtendableFEM
77

8+
using ChunkSplitters: chunks
89
using CommonSolve: CommonSolve
910
using DiffResults: DiffResults
1011
using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDSIGNATURES
@@ -68,6 +69,7 @@ using LinearSolve: LinearSolve, LinearProblem, UMFPACKFactorization, deleteat!,
6869
using Printf: Printf, @printf, @sprintf
6970
using SparseArrays: SparseArrays, AbstractSparseArray, SparseMatrixCSC, findnz, nnz,
7071
nzrange, rowvals, sparse, SparseVector
72+
using StaticArrays: @MArray
7173
using SparseDiffTools: SparseDiffTools, ForwardColorJacCache,
7274
forwarddiff_color_jacobian!, matrix_colors
7375
using Symbolics: Symbolics

src/helper_functions.jl

Lines changed: 96 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,29 @@ function get_periodic_coupling_matrix(
209209
return _get_periodic_coupling_matrix(FES, xgrid, b_from, b_to, give_opposite!; kwargs...)
210210
end
211211

212+
# merge matrix B into A, overriding the entries of A if an entry is both present in A and B
213+
function merge!(A::ExtendableSparseMatrix, B::ExtendableSparseMatrix)
214+
rows, cols, values = findnz(B)
215+
for (row, col, value) in zip(rows, cols, values)
216+
A[row, col] = value
217+
end
218+
return nothing
219+
end
220+
212221
function _get_periodic_coupling_matrix(
213222
FES::FESpace{Tv},
214223
xgrid::ExtendableGrid{TvG, TiG},
215224
b_from,
216225
b_to,
217226
give_opposite!::Function;
218227
mask = :auto,
219-
sparsity_tol = 1.0e-12
228+
sparsity_tol = 1.0e-12,
229+
parallel = false,
230+
threads = Threads.nthreads()
220231
) where {Tv, TvG, TiG}
221232

222-
@info "Computing periodic coupling matrix. This may take a while."
233+
nthr = parallel ? threads : 1
234+
@info "Computing periodic coupling matrix with $nthr thread(s). This may take a while."
223235

224236
if typeof(b_from) <: Int
225237
b_from = [b_from]
@@ -239,19 +251,6 @@ function _get_periodic_coupling_matrix(
239251
# FE basis dofs on each boundary face
240252
dofs_on_boundary = FES[BFaceDofs]
241253

242-
# fe vector used for interpolation
243-
fe_vector = FEVector(FES)
244-
# to be sure
245-
fill!(fe_vector.entries, 0.0)
246-
247-
# FE target vector for interpolation with sparse entries
248-
fe_vector_target = FEVector(FES; entries = SparseVector{Float64, Int64}(FES.ndofs, Int64[], Float64[]))
249-
250-
251-
# resulting sparse matrix
252-
n = length(fe_vector.entries)
253-
result = ExtendableSparseMatrix(n, n)
254-
255254
# face numbers of the boundary faces
256255
face_numbers_of_bfaces = xgrid[BFaceFaces]
257256

@@ -307,26 +306,18 @@ function _get_periodic_coupling_matrix(
307306
return
308307
end
309308

310-
eval_point, set_start = interpolate_on_boundaryfaces(fe_vector, xgrid, give_opposite!)
311-
312309
# precompute approximate search region for each boundary face in b_from
313310
searchareas = ExtendableGrids.VariableTargetAdjacency(TiG)
314311
coords = xgrid[Coordinates]
315312
facenodes = xgrid[FaceNodes]
316-
faces_to = zeros(Int, 1)
317313
coords_from = coords[:, facenodes[:, 1]]
318-
coords_to = coords[:, facenodes[:, 1]]
319314
nodes_per_faces = size(coords_from, 2)
320315
dim = size(coords_from, 1)
321-
box_from = [Float64[0, 0], Float64[0, 0], Float64[0, 0]]
322-
box_to = [Float64[0, 0], Float64[0, 0], Float64[0, 0]]
323-
nfaces_to = 0
316+
box_from = @MArray [Float64[0, 0], Float64[0, 0], Float64[0, 0]]
324317
for face_from in faces_in_b_from
325318
for j in 1:nodes_per_faces, k in 1:dim
326319
coords_from[k, j] = coords[k, facenodes[j, face_from]]
327320
end
328-
fill!(faces_to, 0)
329-
nfaces_to = 0
330321

331322
# transfer the coords_from to the other side
332323
transfer_face!(coords_from)
@@ -337,26 +328,44 @@ function _get_periodic_coupling_matrix(
337328
box_from[k][2] = maximum(view(coords_from, k, :))
338329
end
339330

340-
for face_to in faces_in_b_to
341-
for j in 1:nodes_per_faces, k in 1:dim
342-
coords_to[k, j] = coords[k, facenodes[j, face_to]]
343-
end
344-
for k in 1:dim
345-
box_to[k][1] = minimum(view(coords_to, k, :))
346-
box_to[k][2] = maximum(view(coords_to, k, :))
347-
end
331+
function inner_loop(faces_chunk)
332+
# some data
333+
local coords_to = coords[:, facenodes[:, 1]]
334+
local box_to = @MArray [Float64[0, 0], Float64[0, 0], Float64[0, 0]]
335+
local faces_to = Int[]
336+
337+
for face_to in faces_chunk
338+
for j in 1:nodes_per_faces, k in 1:dim
339+
coords_to[k, j] = coords[k, facenodes[j, face_to]]
340+
end
341+
for k in 1:dim
342+
box_to[k][1] = minimum(view(coords_to, k, :))
343+
box_to[k][2] = maximum(view(coords_to, k, :))
344+
end
348345

349-
if do_boxes_overlap(box_from, box_to)
350-
nfaces_to += 1
351-
if length(faces_to) < nfaces_to
346+
if do_boxes_overlap(box_from, box_to)
352347
push!(faces_to, face_to)
353-
else
354-
faces_to[nfaces_to] = face_to
355348
end
356349
end
350+
351+
return faces_to
357352
end
358353

359-
append!(searchareas, view(faces_to, 1:nfaces_to))
354+
if parallel && nthr > 1
355+
# create chunks to split this range for the threads
356+
faces_chunks = chunks(faces_in_b_to, n = nthr)
357+
358+
tasks = map(faces_chunks) do faces_chunk
359+
Threads.@spawn inner_loop(faces_chunk)
360+
end
361+
362+
# put all results together
363+
faces_to = vcat(fetch.(tasks)...)
364+
else
365+
faces_to = inner_loop(faces_in_b_to)
366+
end
367+
368+
append!(searchareas, faces_to)
360369
end
361370

362371
# throw error if no search area had been found for a bface
@@ -366,13 +375,31 @@ function _get_periodic_coupling_matrix(
366375
end
367376
end
368377

369-
# loop over boundary face indices: we need this index for dofs_on_boundary
370-
for i_boundary_face in 1:n_boundary_faces
378+
# we are only interest in global bface numbers on the "from" boundary
379+
bfaces_of_interest = filter(face -> boundary_regions[face] in b_from, 1:n_boundary_faces)
380+
n_bface_of_interest = length(bfaces_of_interest)
381+
382+
# loop over boundary face indices in a chunk: we need this index for dofs_on_boundary
383+
function compute_chunk_result(chunk)
384+
385+
# prepare data for this chunk
386+
# we need our own copy of the FE Space to avoid data race in the pre-computed interpolators
387+
our_FES = deepcopy(FES)
388+
local fe_vector = FEVector(our_FES)
389+
# to be sure
390+
fill!(fe_vector.entries, 0.0)
371391

372-
# for each boundary face: check if in b_from
373-
if boundary_regions[i_boundary_face] in b_from
392+
local entries = SparseVector{Float64, Int64}(our_FES.ndofs, Int64[], Float64[])
393+
local fe_vector_target = FEVector(our_FES; entries)
374394

375-
local_dofs = @views dofs_on_boundary[:, i_boundary_face]
395+
local n = length(fe_vector.entries)
396+
local result = ExtendableSparseMatrix(n, n)
397+
398+
local eval_point, _ = interpolate_on_boundaryfaces(fe_vector, xgrid, give_opposite!)
399+
400+
for i_boundary_face in chunk
401+
402+
local local_dofs = @views dofs_on_boundary[:, i_boundary_face]
376403
for local_dof in local_dofs
377404
# compute number of component
378405
if mask[1 + ((local_dof - 1) ÷ coffset)] == 0.0
@@ -388,7 +415,7 @@ function _get_periodic_coupling_matrix(
388415

389416
# interpolate on the opposite boundary using x_trafo = give_opposite
390417

391-
j = findfirst(==(face_numbers_of_bfaces[i_boundary_face]), faces_in_b_from)
418+
local j = findfirst(==(face_numbers_of_bfaces[i_boundary_face]), faces_in_b_from)
392419
if j <= 0
393420
throw("face $(face_numbers_of_bfaces[i_boundary_face]) is not on source boundary. Are the from/to regions and the give_opposite function correct?")
394421
end
@@ -410,6 +437,31 @@ function _get_periodic_coupling_matrix(
410437
end
411438
end
412439
end
440+
441+
return result
442+
end
443+
444+
if parallel && nthr > 1
445+
# we loop ober the n_boundary_faces in parallel:
446+
# create chunks to split this range for the threads
447+
bface_chunks = chunks(bfaces_of_interest, n = nthr)
448+
449+
# show start all chunks in parallel
450+
tasks = map(bface_chunks) do chunk
451+
Threads.@spawn compute_chunk_result(chunk)
452+
end
453+
454+
# @info "done..."
455+
# wait for all chunks to finish and get results
456+
results = fetch.(tasks)
457+
458+
# merge all matrices
459+
result = results[begin]
460+
for res_i in results[(begin + 1):end]
461+
merge!(result, res_i)
462+
end
463+
else
464+
result = compute_chunk_result(bfaces_of_interest)
413465
end
414466

415467
sp_result = sparse(result)

0 commit comments

Comments
 (0)