Skip to content

Commit 9d77747

Browse files
committed
Working Parallel BP
1 parent 9d937aa commit 9d77747

9 files changed

Lines changed: 545 additions & 0 deletions

File tree

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2929

3030
[weakdeps]
3131
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
32+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
33+
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
3234

3335
[extensions]
3436
ITensorNetworksNextTensorOperationsExt = "TensorOperations"
37+
ITensorNetworksNextDistributedExt = "Distributed"
38+
ITensorNetworksNextDaggerExt = "Dagger"
3539

3640
[compat]
3741
AbstractTrees = "0.4.5"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
module ITensorNetworksNextDaggerExt
2+
3+
using Dagger
4+
using Dagger.Distributed
5+
using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerNestedAlgorithm, DaggerState,
6+
ITensorNetworksNextParallel
7+
8+
import AlgorithmsInterface as AI
9+
import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE
10+
11+
12+
function ITensorNetworksNextParallel.DaggerNestedAlgorithm(f::Function, iterable; workers = workers(), kwargs...)
13+
return DaggerNestedAlgorithm(; algorithms = map(f, iterable), workers, kwargs...)
14+
end
15+
16+
function initialize_dagger_state(
17+
problem::AIE.Problem,
18+
algorithm::AIE.Algorithm;
19+
iterate,
20+
remote_subiterates = Dict{Int, Dagger.Chunk}(),
21+
)
22+
23+
stopping_criterion_state = AI.initialize_state(
24+
problem, algorithm, algorithm.stopping_criterion
25+
)
26+
27+
remote_results = Dict{Int, Dagger.DTask}()
28+
29+
return DaggerState(; iterate, remote_subiterates, stopping_criterion_state, remote_results)
30+
end
31+
32+
function AI.initialize_state(
33+
problem::AIE.Problem,
34+
algorithm::DaggerNestedAlgorithm;
35+
kwargs...
36+
)
37+
return initialize_dagger_state(problem, algorithm; kwargs...)
38+
end
39+
40+
function AIE.get_subproblem(
41+
problem::AIE.Problem,
42+
algorithm::AIE.NestedAlgorithm,
43+
state::DaggerState
44+
)
45+
subproblem = problem
46+
subalgorithm = algorithm.algorithms[state.iteration]
47+
48+
iterate = state.iterate
49+
remote_subiterates = state.remote_subiterates
50+
51+
substate = AI.initialize_state(subproblem, subalgorithm; iterate, remote_subiterates)
52+
53+
return subproblem, subalgorithm, substate
54+
end
55+
56+
57+
function AI.step!(
58+
problem::AI.Problem,
59+
algorithm::DaggerNestedAlgorithm,
60+
state::DaggerState;
61+
kwargs...
62+
)
63+
64+
subproblem, subalgorithm, subiterate_chunk = AIE.get_subproblem(problem, algorithm, state)
65+
66+
dtask = Dagger.@spawn AI.solve(subproblem, subalgorithm; iterate = subiterate_chunk)
67+
68+
AIE.set_substate!(problem, algorithm, state, dtask)
69+
70+
return state
71+
end
72+
73+
function AIE.set_substate!(
74+
::AIE.Problem,
75+
::DaggerNestedAlgorithm,
76+
state::DaggerState,
77+
dtask::Dagger.DTask,
78+
)
79+
state.remote_results[state.iteration] = dtask
80+
81+
return state
82+
end
83+
84+
include("daggerbeliefpropagation.jl")
85+
86+
end # ITensorNetworksNextDaggerExt
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
using Dagger
2+
using Dagger.Distributed
3+
4+
using DataGraphs: DataGraphs, get_edge_data, get_vertex_data, is_edge_assigned,
5+
is_vertex_assigned, set_edge_data!, set_vertex_data!, underlying_graph
6+
using Dictionaries: Indices
7+
using Graphs: AbstractEdge, AbstractGraph, dst, edges, src, vertices
8+
using ITensorNetworksNext.ITensorNetworksNextParallel: DaggerBeliefPropagationCache,
9+
DaggerNestedAlgorithm, DaggerState, ITensorNetworksNextParallel, dagger_algorithm,
10+
subcache
11+
using ITensorNetworksNext: BeliefPropagation, BeliefPropagationCache,
12+
BeliefPropagationProblem, BeliefPropagationState, ITensorNetworksNext,
13+
beliefpropagation, forest_cover_edge_sequence, select_algorithm
14+
using NamedGraphs.PartitionedGraphs: QuotientVertex, quotientedges, quotientvertices
15+
using NamedGraphs: NamedGraphs
16+
using NamedGraphs.GraphsExtensions: boundary_edges
17+
18+
function ITensorNetworksNextParallel.subcache(cache::DaggerBeliefPropagationCache, inds)
19+
return subcache(cache.underlying_cache, inds)
20+
end
21+
22+
function ITensorNetworksNextParallel.DaggerBeliefPropagationCache(network::AbstractGraph)
23+
underlying_cache = BeliefPropagationCache(network)
24+
25+
keys = Indices(quotientvertices(underlying_cache))
26+
27+
workers = Iterators.cycle(Distributed.workers())
28+
worker_dict = similar(keys, Int)
29+
30+
for quotient_vertex in keys
31+
worker, workers = Iterators.peel(workers)
32+
worker_dict[quotient_vertex] = worker
33+
end
34+
35+
quotient_chunks = map(keys) do quotient_vertex
36+
worker = worker_dict[quotient_vertex]
37+
iterate = subcache(underlying_cache, quotient_vertex)
38+
chunk = Dagger.@mutable worker = worker BeliefPropagationState(; iterate)
39+
return chunk
40+
end
41+
42+
return DaggerBeliefPropagationCache(underlying_cache, quotient_chunks)
43+
end
44+
45+
DataGraphs.underlying_graph(cache::DaggerBeliefPropagationCache) = underlying_graph(cache.underlying_cache)
46+
47+
DataGraphs.is_vertex_assigned(bpc::DaggerBeliefPropagationCache, vertex) = is_vertex_assigned(bpc.underlying_cache, vertex)
48+
DataGraphs.is_edge_assigned(bpc::DaggerBeliefPropagationCache, edge) = is_edge_assigned(bpc.undelying_cache, edge)
49+
50+
DataGraphs.get_vertex_data(bpc::DaggerBeliefPropagationCache, vertex) = get_vertex_data(bpc.underlying_cache, vertex)
51+
DataGraphs.get_edge_data(bpc::DaggerBeliefPropagationCache, edge::AbstractEdge) = get_edge_data(bpc.undelying_caches, edge)
52+
53+
DataGraphs.set_vertex_data!(bpc::DaggerBeliefPropagationCache, val, vertex) = set_vertex_data!(bpc.underlying_cache, val, vertex)
54+
DataGraphs.set_edge_data!(bpc::DaggerBeliefPropagationCache, val, edge) = set_edge_data!(bpc.underlying_cache, val, edge)
55+
56+
NamedGraphs.to_graph_index(::DaggerBeliefPropagationCache, qv::QuotientVertex) = qv
57+
function DataGraphs.get_index_data(cache::DaggerBeliefPropagationCache, qv::QuotientVertex)
58+
return cache.quotient_chunks[qv]
59+
end
60+
61+
function ITensorNetworksNext.beliefpropagation_sweep(cache::DaggerBeliefPropagationCache; edges, workers = workers(), kwargs...)
62+
63+
keys = collect(quotientvertices(cache))
64+
65+
return dagger_algorithm(keys; keys, workers) do quotient_vertex
66+
67+
subcache = fetch(cache[quotient_vertex]).iterate
68+
69+
subcache_edges = forest_cover_edge_sequence(subcache) edges
70+
incoming_edges = boundary_edges(cache, vertices(cache, quotient_vertex); dir = :in)
71+
72+
alg = select_algorithm(
73+
beliefpropagation,
74+
subcache;
75+
# Don't update the incoming messages
76+
edges = setdiff(subcache_edges, incoming_edges),
77+
maxiter = 1,
78+
kwargs...
79+
)
80+
81+
return alg
82+
end
83+
end
84+
85+
function AI.initialize_state(
86+
problem::AIE.Problem,
87+
algorithm::BeliefPropagation{<:DaggerNestedAlgorithm};
88+
kwargs...
89+
)
90+
return initialize_dagger_state(problem, algorithm; kwargs...)
91+
end
92+
93+
function AIE.get_subproblem(
94+
problem::BeliefPropagationProblem,
95+
algorithm::DaggerNestedAlgorithm,
96+
state::DaggerState,
97+
)
98+
subproblem = problem
99+
subalgorithm = algorithm.algorithms[state.iteration]
100+
101+
quotient_vertex = algorithm.keys[state.iteration]
102+
103+
cache = state.iterate.iterate
104+
105+
subiterate = cache[quotient_vertex]
106+
107+
return subproblem, subalgorithm, subiterate
108+
end
109+
110+
function AIE.set_substate!(
111+
::BeliefPropagationProblem,
112+
algorithm::AIE.NestedAlgorithm,
113+
state::AIE.State,
114+
substate::DaggerState,
115+
)
116+
117+
dst_cache = state.iterate.iterate
118+
119+
state.iterate.maxdiff = 0.0
120+
121+
current_algorithm = algorithm.algorithms[state.iteration]
122+
123+
for (i, quotient_vertex) in enumerate(current_algorithm.keys)
124+
get_maxdiff = dtask -> dtask.iterate.maxdiff
125+
src_maxdiff = fetch(Dagger.@spawn get_maxdiff(substate.remote_results[i]))
126+
127+
if src_maxdiff > state.iterate.maxdiff
128+
state.iterate.maxdiff = src_maxdiff
129+
end
130+
end
131+
132+
133+
transfer_edges! = (dst_chunk, src_chunk, edges) -> begin
134+
src_subcache = src_chunk.iterate
135+
dst_subcache = dst_chunk.iterate
136+
for edge in edges
137+
dst_subcache[edge] = src_subcache[edge]
138+
end
139+
end
140+
141+
transfer_dtasks = map(quotientedges(dst_cache)) do quotient_edge
142+
src_subcache = dst_cache[src(quotient_edge)]
143+
dst_subcache = dst_cache[dst(quotient_edge)]
144+
return Dagger.@spawn transfer_edges!(dst_subcache, fetch(src_subcache), edges(dst_cache, quotient_edge))
145+
end
146+
147+
wait.(transfer_dtasks)
148+
149+
return state
150+
end
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module ITensorNetworksNextDistributedExt
2+
3+
using Distributed
4+
5+
import AlgorithmsInterface as AI
6+
import ITensorNetworksNext.AlgorithmsInterfaceExtensions as AIE
7+
8+
import ITensorNetworksNext.ITensorNetworksNextParallel as Parallel
9+
10+
function initialize_distributed_state(
11+
problem::AIE.Problem,
12+
algorithm::AIE.Algorithm;
13+
keys,
14+
iterate,
15+
kwargs...
16+
)
17+
stopping_criterion_state = AI.initialize_state(
18+
problem, algorithm, algorithm.stopping_criterion
19+
)
20+
remote_results = Dict{eltype(keys), Distributed.Future}()
21+
22+
return Parallel.DistributedState(; iterate, stopping_criterion_state, remote_results)
23+
end
24+
25+
function AI.initialize_state(
26+
problem::AIE.Problem,
27+
algorithm::Parallel.DistributedNestedAlgorithm;
28+
kwargs...
29+
)
30+
return initialize_distributed_state(problem, algorithm; keys = algorithm.keys, kwargs...)
31+
end
32+
33+
function Parallel.DistributedNestedAlgorithm(f::Function, iterable; kwargs...)
34+
return Parallel.DistributedNestedAlgorithm(; algorithms = map(f, iterable), kwargs...)
35+
end
36+
37+
function AIE.get_subproblem(
38+
problem::AI.Problem, algorithm::Parallel.DistributedNestedAlgorithm, state::Parallel.DistributedState
39+
)
40+
subproblem = problem
41+
subalgorithm = algorithm.algorithms[state.iteration]
42+
43+
return subproblem, subalgorithm, state.iterate
44+
end
45+
46+
function AI.step!(
47+
problem::AI.Problem,
48+
algorithm::Parallel.DistributedNestedAlgorithm,
49+
state::Parallel.DistributedState;
50+
kwargs...
51+
)
52+
53+
subproblem, subalgorithm, subiterate = AIE.get_subproblem(problem, algorithm, state)
54+
55+
# Do whatever should have happened at `step!`, but store the result as a future.
56+
57+
function solve(subproblem, subalgorithm, iterate)
58+
rv = AI.solve(subproblem, subalgorithm; iterate)
59+
return rv
60+
end
61+
62+
future = remotecall(solve, algorithm.workers, subproblem, subalgorithm, subiterate)
63+
64+
AIE.set_substate!(problem, algorithm, state, future)
65+
66+
return state
67+
end
68+
69+
function AIE.set_substate!(
70+
::AIE.Problem,
71+
algorithm::Parallel.DistributedNestedAlgorithm,
72+
state::Parallel.DistributedState,
73+
future::Distributed.Future,
74+
)
75+
key = algorithm.keys[state.iteration]
76+
77+
state.remote_results[key] = future
78+
79+
return state
80+
end
81+
82+
include("distributedbeliefpropagation.jl")
83+
84+
end # ITensorNetworksNextDistributedExt

0 commit comments

Comments
 (0)