Skip to content

Commit bc2f7d7

Browse files
committed
MPI: Fixing non-uniformity in dict-key iteration
1 parent f8f5756 commit bc2f7d7

6 files changed

Lines changed: 71 additions & 23 deletions

File tree

src/datadeps/aliasing.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,10 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper,
463463
if !haskey(state.ainfos_owner, target_ainfo)
464464
overlaps = Set{AliasingWrapper}()
465465
push!(overlaps, target_ainfo)
466-
for other_ainfo in keys(state.ainfos_owner)
466+
other_ainfos = (Dagger.current_acceleration() isa Dagger.MPIAcceleration
467+
? sort(collect(keys(state.ainfos_owner)), by=hash)
468+
: keys(state.ainfos_owner))
469+
for other_ainfo in other_ainfos
467470
target_ainfo == other_ainfo && continue
468471
if will_alias(target_ainfo, other_ainfo)
469472
# Mark us and them as overlapping

src/datadeps/queue.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr
332332
sig = Sch.signature(sch_state, f, map(first, chunks_locality))
333333
task_pressure = get(sch_state.signature_time_cost, sig, 1000^3)
334334

335-
# Shuffle procs around, so equally-costly procs are equally considered
336-
P = randperm(length(all_procs))
337-
procs = getindex.(Ref(all_procs), P)
335+
# Shuffle procs around, so equally-costly procs are equally considered (skip when MPI for deterministic tie-breaking)
336+
procs = if current_acceleration() isa Dagger.MPIAcceleration
337+
collect(all_procs)
338+
else
339+
P = randperm(length(all_procs))
340+
getindex.(Ref(all_procs), P)
341+
end
338342

339343
# Sort by lowest cost first
340344
sort!(procs, by=p->costs[p])
@@ -397,7 +401,11 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr
397401
delete!(spaces_completed, our_space)
398402
continue
399403
end
400-
our_proc = rand(our_space_procs)
404+
our_proc = if current_acceleration() isa Dagger.MPIAcceleration
405+
first(sort(collect(our_space_procs), by=short_name))
406+
else
407+
rand(our_space_procs)
408+
end
401409
break
402410
end
403411

src/datadeps/scheduling.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ function datadeps_schedule_task(sched::UltraScheduler, state::DataDepsState, all
111111
delete!(spaces_completed, our_space)
112112
continue
113113
end
114-
our_proc = rand(our_space_procs)
114+
our_proc = if Dagger.current_acceleration() isa Dagger.MPIAcceleration
115+
first(sort(collect(our_space_procs), by=Dagger.short_name))
116+
else
117+
rand(our_space_procs)
118+
end
115119
break
116120
end
117121

src/mpi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ end
367367

368368
const DEADLOCK_DETECT = TaskLocalValue{Bool}(()->true)
369369
const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0)
370-
const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->60.0)
370+
const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->600.0)
371371
const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}())
372372

373373
struct InplaceInfo

src/sch/Sch.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,25 @@ struct ScheduleTaskSpec
532532
est_alloc_util::UInt64
533533
est_occupancy::UInt32
534534
end
535+
536+
"Ordering key for task locations when using MPI acceleration (deterministic across ranks)."
function _mpi_fire_order_key(loc::ScheduleTaskLocation)
    # Map the (gproc, proc) pair to a stable (Int, Int) tuple so that every
    # MPI rank iterates fire locations in the same order (avoids deadlock).
    # MPI-backed processors expose their rank directly; anything else falls
    # back to its root worker id as the ordering key.
    rank_of(x) = x isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? x.rank : root_worker_id(x)
    return (rank_of(loc.gproc), rank_of(loc.proc))
end
544+
545+
"Ordering key for a single Processor when using MPI acceleration (deterministic across ranks)."
function _mpi_proc_rank(proc::Processor)
    # Order by (parent rank, processor rank) so sorting by this key is total
    # and identical on all ranks. Non-MPI processors fall back to their root
    # worker id, keeping the key well-defined for mixed processor sets.
    rank_of(x) = x isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? x.rank : root_worker_id(x)
    return (rank_of(get_parent(proc)), rank_of(proc))
end
553+
535554
@reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options))
536555
lock(state.lock) do
537556
safepoint(state)
@@ -688,14 +707,17 @@ end
688707
# Fire all newly-scheduled tasks (owner/local first, then by fire_order_key to avoid MPI execute! deadlock)
689708
@label fire_tasks
690709
task_locs = collect(keys(to_fire))
710+
if Dagger.current_acceleration() isa Dagger.MPIAcceleration
711+
sort!(task_locs, by=_mpi_fire_order_key)
712+
end
691713
rank = try
692714
M = parentmodule(@__MODULE__)
693715
(isdefined(M, :MPI) && M.MPI.Initialized()) ? Int(M.MPI.Comm_rank(M.MPI.COMM_WORLD)) : nothing
694716
catch
695717
nothing
696718
end
697-
Core.println("fire order rank=", rank, " task_locs=", task_locs)
698-
for task_loc in task_locs
719+
for (i, task_loc) in enumerate(task_locs)
720+
#Core.println("fire_order rank=", rank, " [", i, "/", length(task_locs), "] task_loc=", task_loc)
699721
fire_tasks!(ctx, task_loc, to_fire[task_loc], state)
700722
end
701723
to_fire_cleanup()
@@ -1141,12 +1163,15 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
11411163
# Try to steal a task
11421164
@maybelog ctx timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing)
11431165

1144-
# Try to steal from local queues randomly
1166+
# Try to steal from local queues randomly (deterministic order when MPI to avoid deadlocks)
11451167
# TODO: Prioritize stealing from busiest processors
11461168
states = proc_states_values(uid)
1147-
# TODO: Try to pre-allocate this
1148-
P = randperm(length(states))
1149-
for state in getindex.(Ref(states), P)
1169+
order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration
1170+
sort(collect(1:length(states)), by=i->_mpi_proc_rank(states[i].state.proc))
1171+
else
1172+
randperm(length(states))
1173+
end
1174+
for state in getindex.(Ref(states), order)
11501175
other_istate = state.state
11511176
if other_istate.proc === to_proc
11521177
continue
@@ -1355,11 +1380,15 @@ function do_tasks(to_proc, return_queue, tasks)
13551380
end
13561381
notify(istate.reschedule)
13571382

1358-
# Kick other processors to make them steal
1383+
# Kick other processors to make them steal (deterministic order when MPI to avoid deadlocks)
13591384
# TODO: Alternatively, automatically balance work instead of blindly enqueueing
13601385
states = proc_states_values(uid)
1361-
P = randperm(length(states))
1362-
for other_state in getindex.(Ref(states), P)
1386+
order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration
1387+
sort(collect(1:length(states)), by=i->_mpi_proc_rank(states[i].state.proc))
1388+
else
1389+
randperm(length(states))
1390+
end
1391+
for other_state in getindex.(Ref(states), order)
13631392
other_istate = other_state.state
13641393
if other_istate.proc === to_proc
13651394
continue
@@ -1477,11 +1506,13 @@ Executes a single task specified by `task` on `to_proc`.
14771506
#= FIXME: This isn't valid if x is written to
14781507
x = if x isa Chunk
14791508
value = lock(TASK_SYNC) do
1480-
if haskey(CHUNK_CACHE, x)
1481-
Some{Any}(get!(CHUNK_CACHE[x], to_proc) do
1482-
# Convert from cached value
1483-
# TODO: Choose "closest" processor of same type first
1484-
some_proc = first(keys(CHUNK_CACHE[x]))
1509+
if haskey(CHUNK_CACHE, x)
1510+
Some{Any}(get!(CHUNK_CACHE[x], to_proc) do
1511+
# Convert from cached value
1512+
# TODO: Choose "closest" processor of same type first
1513+
cache_procs = keys(CHUNK_CACHE[x])
1514+
some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ?
1515+
argmin(_mpi_proc_rank, cache_procs) : first(cache_procs)
14851516
some_x = CHUNK_CACHE[x][some_proc]
14861517
@dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x"
14871518
@invokelatest move(some_proc, to_proc, some_x)

src/sch/util.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,12 +590,14 @@ end
590590
end
591591
chunks_cleanup()
592592

593-
# Shuffle procs around, so equally-costly procs are equally considered
593+
# Shuffle procs around, so equally-costly procs are equally considered (skip shuffle when MPI for deterministic tie-breaking)
594594
np = length(procs)
595595
@reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin
596596
resize!(P, np)
597597
copyto!(P, 1:np)
598-
randperm!(P)
598+
if !(Dagger.current_acceleration() isa Dagger.MPIAcceleration)
599+
randperm!(P)
600+
end
599601
for idx in 1:np
600602
sorted_procs[idx] = procs[P[idx]]
601603
end

0 commit comments

Comments
 (0)