Skip to content

Commit 483e5f4

Browse files
authored
Merge pull request #213 from JordiManyer/find-rcv-ids
Setting default algorithm to discover the communication graph
2 parents 8138104 + 3512ebc commit 483e5f4

9 files changed

Lines changed: 79 additions & 12 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
*.jl.mem
66
.DS_Store
77
Manifest.toml
8+
LocalPreferences.toml
89
docs/build/
910
tmp/
1011
docs/src/examples.md
12+
docs/src/jacobi_tutorial.md
1113

1214
HPCG/src/results/
1315

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.5.14] - 2025-09-04
9+
10+
- Added compile-time preference to choose the algorithm used within `default_rcv_ids`.
11+
812
## [0.5.13] - 2025-07-08
913

1014
### Fixed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PartitionedArrays"
22
uuid = "5a9dfac6-5c52-46f7-8278-5e2210713be9"
33
authors = ["Francesc Verdugo <f.verdugo.rojano@vu.nl> and contributors"]
4-
version = "0.5.13"
4+
version = "0.5.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -12,6 +12,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1212
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
15+
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1516
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -26,6 +27,7 @@ Distances = "0.10"
2627
FillArrays = "0.10, 0.11, 0.12, 0.13, 1"
2728
IterativeSolvers = "0.9"
2829
MPI = "0.16, 0.17, 0.18, 0.19, 0.20"
30+
Preferences = "1"
2931
SparseMatricesCSR = "0.6"
3032
StaticArrays = "1"
3133
julia = "1.1"

docs/src/reference/backends.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
```@autodocs
66
Modules = [PartitionedArrays]
77
Pages = ["mpi_array.jl"]
8+
Filter = t -> string(t) != "find_rcv_ids_ibarrier"
89
```
910

1011
## Debug

docs/src/reference/primitives.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ gather
88
gather!
99
allocate_gather
1010
```
11+
1112
## Scatter
1213

1314
```@docs
@@ -45,7 +46,8 @@ ExchangeGraph(snd)
4546
exchange
4647
exchange!
4748
allocate_exchange
49+
default_find_rcv_ids
50+
set_default_find_rcv_ids
51+
find_rcv_ids_gather_scatter
52+
find_rcv_ids_ibarrier
4853
```
49-
50-
51-

src/PartitionedArrays.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ import IterativeSolvers
1111
import Distances
1212
using BlockArrays
1313
using Adapt
14+
using Preferences
15+
16+
export set_default_find_rcv_ids
17+
include("preferences.jl")
1418

1519
export length_to_ptrs!
1620
export rewind_ptrs!
@@ -54,6 +58,7 @@ export ExchangeGraph
5458
export exchange
5559
export exchange!
5660
export allocate_exchange
61+
export default_find_rcv_ids
5762
export find_rcv_ids_gather_scatter
5863
export setup_non_blocking_reduction
5964
export non_blocking_reduction

src/mpi_array.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -660,14 +660,23 @@ end
660660
Issend(data, dest::Integer, tag::Integer, comm::MPI.Comm, req=MPI.Request()) =
661661
Issend(MPI.Buffer_send(data), dest, tag, comm, req)
662662

663-
664663
function default_find_rcv_ids(::MPIArray)
665-
find_rcv_ids_gather_scatter
664+
@static if default_find_rcv_ids_algorithm == "gather_scatter"
665+
find_rcv_ids_gather_scatter
666+
elseif default_find_rcv_ids_algorithm == "ibarrier"
667+
find_rcv_ids_ibarrier
668+
else
669+
error("Unknown algorithm: $(default_find_rcv_ids_algorithm)")
670+
end
666671
end
667672

668673
"""
669-
Implements Alg. 2 in https://dl.acm.org/doi/10.1145/1837853.1693476
670-
The algorithm's complexity is claimed to be O(log(p))
674+
find_rcv_ids_ibarrier(snd_ids::MPIArray)
675+
676+
Finds the `rcv` side of an `ExchangeGraph` out of the `snd` side information.
677+
678+
This strategy implements Alg. 2 in https://dl.acm.org/doi/10.1145/1837853.1693476.
679+
The algorithm's complexity is claimed to be O(log(p)).
671680
"""
672681
function find_rcv_ids_ibarrier(snd_ids::MPIArray{<:AbstractVector{T}}) where T
673682
comm = snd_ids.comm

src/preferences.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
"""
3+
set_default_find_rcv_ids(algorithm::String)
4+
5+
Sets the default algorithm to discover communication neighbors. The available algorithms are:
6+
7+
- `gather_scatter`: Gathers neighbors in a single processor, builds the communications graph
8+
and then scatters the information back to all processors. See [`find_rcv_ids_gather_scatter`](@ref).
9+
10+
- `ibarrier`: Implements Alg. 2 in https://dl.acm.org/doi/10.1145/1837853.1693476. See [`find_rcv_ids_ibarrier`](@ref).
11+
12+
Feature only available in Julia 1.6 and later due to restrictions from `Preferences.jl`.
13+
"""
14+
function set_default_find_rcv_ids(algorithm::String)
15+
if !(algorithm in ("gather_scatter", "ibarrier"))
16+
throw(ArgumentError("Invalid algorihtm: \"$(algorithm)\""))
17+
end
18+
19+
# Set it in our runtime values, as well as saving it to disk
20+
@set_preferences!("default_find_rcv_ids" => algorithm)
21+
@info("New default algorithm set; restart your Julia session for this change to take effect!")
22+
end
23+
24+
@static if VERSION >= v"1.6"
25+
const default_find_rcv_ids_algorithm = @load_preference("default_find_rcv_ids", "gather_scatter")
26+
else
27+
const default_find_rcv_ids_algorithm = "gather_scatter"
28+
end

src/primitives.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,13 @@ function Base.show(io::IO,k::MIME"text/plain",data::ExchangeGraph)
762762
println(io,typeof(data)," with $(length(data.snd)) nodes")
763763
end
764764

765+
"""
766+
default_find_rcv_ids(::AbstractArray)
767+
768+
Provides a default function to find the `rcv` side of an
769+
`ExchangeGraph` out of the `snd` side information.
770+
Its behaviour can be statically changed using [`set_default_find_rcv_ids`](@ref).
771+
"""
765772
function default_find_rcv_ids(::AbstractArray)
766773
find_rcv_ids_gather_scatter
767774
end
@@ -779,10 +786,11 @@ are set to `snd`. Otherwise, either the optional `neighbors` or
779786
`neighbors` is also an `ExchangeGraph`
780787
that contains a super set of the outgoing and incoming neighbors
781788
associated with `snd`. It is used to find the incoming neighbors `rcv`
782-
efficiently. If `neighbors` are not provided, then `find_rcv_ids`
789+
efficiently. If `neighbors` are not provided, then `find_rcv_ids`
783790
is used (either the user-provided or a default one).
784791
`find_rcv_ids` is a function that implements an algorithm to find the
785-
rcv side of the exchange graph out of the snd side information.
792+
rcv side of the exchange graph out of the snd side information. It
793+
defaults to [`default_find_rcv_ids`](@ref).
786794
"""
787795
function ExchangeGraph(snd;
788796
rcv=nothing,
@@ -836,8 +844,14 @@ function ExchangeGraph_impl_with_find_rcv_ids(snd_ids::AbstractArray,find_rcv_id
836844
ExchangeGraph(snd_ids,rcv_ids)
837845
end
838846

839-
# This strategy gathers the communication graph into one process
840-
# and then scatters back the receivers
847+
"""
848+
find_rcv_ids_gather_scatter(snd_ids::AbstractArray)
849+
850+
Finds the `rcv` side of an `ExchangeGraph` out of the `snd` side information.
851+
852+
This strategy gathers the communication graph into one process
853+
and then scatters back the receivers.
854+
"""
841855
function find_rcv_ids_gather_scatter(snd_ids::AbstractArray)
842856
snd_ids_main = gather(snd_ids)
843857
rcv_ids_main = map(snd_ids_main) do snd_ids_main

0 commit comments

Comments
 (0)