@@ -532,6 +532,25 @@ struct ScheduleTaskSpec
532532 est_alloc_util:: UInt64
533533 est_occupancy:: UInt32
534534end
535+
"""
    _mpi_rank_or_worker(proc)

Return the scalar used to order `proc` under MPI acceleration: the `rank` field
when `proc` is a `Dagger.MPIOSProc` or `Dagger.MPIProcessor`, otherwise
`root_worker_id(proc)`.
"""
function _mpi_rank_or_worker(proc)
    # NOTE(review): the MPI types are checked with `isa` at call time, exactly as
    # the call sites did before this helper existed. Moving them into method
    # signatures would require them to be defined when this file is loaded --
    # confirm the load order before changing this to dispatch.
    if proc isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor}
        return proc.rank
    end
    return root_worker_id(proc)
end

"""
    _mpi_fire_order_key(loc::ScheduleTaskLocation)

Ordering key for task locations when using MPI acceleration (deterministic
across ranks).

Returns the tuple `(g_rank, p_rank)`, where `g_rank` is derived from `loc.gproc`
and `p_rank` from `loc.proc` via [`_mpi_rank_or_worker`](@ref). Tuples compare
lexicographically, so sorting by this key orders by `gproc` first, then `proc`.
"""
function _mpi_fire_order_key(loc::ScheduleTaskLocation)
    return (_mpi_rank_or_worker(loc.gproc), _mpi_rank_or_worker(loc.proc))
end

"""
    _mpi_proc_rank(proc::Processor)

Ordering key for a single `Processor` when using MPI acceleration (deterministic
across ranks).

Returns the tuple `(g_rank, p_rank)`, where `g_rank` is derived from
`get_parent(proc)` and `p_rank` from `proc` itself via
[`_mpi_rank_or_worker`](@ref).
"""
function _mpi_proc_rank(proc::Processor)
    return (_mpi_rank_or_worker(get_parent(proc)), _mpi_rank_or_worker(proc))
end
553+
535554@reuse_scope function schedule! (ctx, state, sch_options, procs= procs_to_use (ctx, sch_options))
536555 lock (state. lock) do
537556 safepoint (state)
@@ -688,14 +707,17 @@ end
688707 # Fire all newly-scheduled tasks (owner/local first, then by fire_order_key to avoid MPI execute! deadlock)
689708 @label fire_tasks
690709 task_locs = collect (keys (to_fire))
710+ if Dagger. current_acceleration () isa Dagger. MPIAcceleration
711+ sort! (task_locs, by= _mpi_fire_order_key)
712+ end
691713 rank = try
692714 M = parentmodule (@__MODULE__ )
693715 (isdefined (M, :MPI ) && M. MPI. Initialized ()) ? Int (M. MPI. Comm_rank (M. MPI. COMM_WORLD)) : nothing
694716 catch
695717 nothing
696718 end
697- Core . println ( " fire order rank= " , rank, " task_locs= " , task_locs)
698- for task_loc in task_locs
719+ for (i, task_loc) in enumerate ( task_locs)
720+ # Core.println("fire_order rank=", rank, " [", i, "/", length( task_locs), "] task_loc=", task_loc)
699721 fire_tasks! (ctx, task_loc, to_fire[task_loc], state)
700722 end
701723 to_fire_cleanup ()
@@ -1141,12 +1163,15 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
11411163 # Try to steal a task
11421164 @maybelog ctx timespan_start (ctx, :proc_steal_local , (;uid, worker= wid, processor= to_proc), nothing )
11431165
1144- # Try to steal from local queues randomly
1166+ # Try to steal from local queues randomly (deterministic order when MPI to avoid deadlocks)
11451167 # TODO : Prioritize stealing from busiest processors
11461168 states = proc_states_values (uid)
1147- # TODO : Try to pre-allocate this
1148- P = randperm (length (states))
1149- for state in getindex .(Ref (states), P)
1169+ order = if Dagger. current_acceleration () isa Dagger. MPIAcceleration
1170+ sort (1 : length (states), by= i-> _mpi_proc_rank (states[i]. state. proc))
1171+ else
1172+ randperm (length (states))
1173+ end
1174+ for state in getindex .(Ref (states), order)
11501175 other_istate = state. state
11511176 if other_istate. proc === to_proc
11521177 continue
@@ -1355,11 +1380,15 @@ function do_tasks(to_proc, return_queue, tasks)
13551380 end
13561381 notify (istate. reschedule)
13571382
1358- # Kick other processors to make them steal
1383+ # Kick other processors to make them steal (deterministic order when MPI to avoid deadlocks)
13591384 # TODO : Alternatively, automatically balance work instead of blindly enqueueing
13601385 states = proc_states_values (uid)
1361- P = randperm (length (states))
1362- for other_state in getindex .(Ref (states), P)
1386+ order = if Dagger. current_acceleration () isa Dagger. MPIAcceleration
1387+ sort (1 : length (states), by= i-> _mpi_proc_rank (states[i]. state. proc))
1388+ else
1389+ randperm (length (states))
1390+ end
1391+ for other_state in getindex .(Ref (states), order)
13631392 other_istate = other_state. state
13641393 if other_istate. proc === to_proc
13651394 continue
@@ -1477,11 +1506,13 @@ Executes a single task specified by `task` on `to_proc`.
14771506 #= FIXME : This isn't valid if x is written to
14781507 x = if x isa Chunk
14791508 value = lock(TASK_SYNC) do
1480- if haskey(CHUNK_CACHE, x)
1481- Some{Any}(get!(CHUNK_CACHE[x], to_proc) do
1482- # Convert from cached value
1483- # TODO : Choose "closest" processor of same type first
1484- some_proc = first(keys(CHUNK_CACHE[x]))
1509+ if haskey(CHUNK_CACHE, x)
1510+ Some{Any}(get!(CHUNK_CACHE[x], to_proc) do
1511+ # Convert from cached value
1512+ # TODO : Choose "closest" processor of same type first
1513+ cache_procs = keys(CHUNK_CACHE[x])
1514+ some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ?
1515+ minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs)
14851516 some_x = CHUNK_CACHE[x][some_proc]
14861517 @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x"
14871518 @invokelatest move(some_proc, to_proc, some_x)
0 commit comments