Skip to content

Commit 988a92f

Browse files
authored
Make transfer rate metric per-processor in scheduler (#697)
1 parent 555ae66 commit 988a92f

3 files changed

Lines changed: 35 additions & 7 deletions

File tree

src/sch/Sch.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Fields:
6767
- `worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}` - Communication channels between the scheduler and each worker
6868
- `signature_time_cost::Dict{Signature,UInt64}` - Cache of estimated CPU time (in nanoseconds) required to compute calls with the given signature
6969
- `signature_alloc_cost::Dict{Signature,UInt64}` - Cache of estimated CPU RAM (in bytes) required to compute calls with the given signature
70-
- `transfer_rate::Ref{UInt64}` - Estimate of the network transfer rate in bytes per second
70+
- `worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}}` - Maps from worker ID to per-processor network transfer rate estimates in bytes per second
7171
- `halt::Base.Event` - Event indicating that the scheduler is halting
7272
- `lock::ReentrantLock` - Lock around operations which modify the state
7373
- `futures::Dict{Thunk, Vector{ThunkFuture}}` - Futures registered for waiting on the result of a thunk.
@@ -93,7 +93,7 @@ struct ComputeState
9393
worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}}
9494
signature_time_cost::Dict{Signature,UInt64}
9595
signature_alloc_cost::Dict{Signature,UInt64}
96-
transfer_rate::Ref{UInt64}
96+
worker_transfer_rate::Dict{Int,Dict{Processor,UInt64}}
9797
halt::Base.Event
9898
lock::ReentrantLock
9999
futures::Dict{Thunk, Vector{ThunkFuture}}
@@ -122,7 +122,7 @@ function start_state(deps::Dict, node_order, chan)
122122
Dict{Int, Tuple{RemoteChannel,RemoteChannel}}(),
123123
Dict{Signature,UInt64}(),
124124
Dict{Signature,UInt64}(),
125-
Ref{UInt64}(1_000_000),
125+
Dict{Int,Dict{Processor,UInt64}}(),
126126
Base.Event(),
127127
ReentrantLock(),
128128
Dict{Thunk, Vector{ThunkFuture}}(),
@@ -157,6 +157,7 @@ function init_proc(state, p, log_sink)
157157
gproc = OSProc(p.pid)
158158
lock(state.lock) do
159159
state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}()
160+
state.worker_transfer_rate[p.pid] = Dict{Processor,UInt64}()
160161

161162
state.worker_storage_pressure[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}()
162163
state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}()
@@ -430,7 +431,12 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt
430431
state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2
431432
state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2
432433
if metadata.transfer_rate !== nothing
433-
state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ 2
434+
old_rate = get(state.worker_transfer_rate[pid], proc, UInt64(0))
435+
if old_rate == 0
436+
state.worker_transfer_rate[pid][proc] = metadata.transfer_rate
437+
else
438+
state.worker_transfer_rate[pid][proc] = (old_rate + metadata.transfer_rate) ÷ 2
439+
end
434440
end
435441
end
436442
if res isa Chunk
@@ -736,6 +742,7 @@ function remove_dead_proc!(ctx, state, proc, options)
736742
@assert options.single !== proc.pid "Single worker failed, cannot continue."
737743
rmprocs!(ctx, [proc])
738744
delete!(state.worker_time_pressure, proc.pid)
745+
delete!(state.worker_transfer_rate, proc.pid)
739746
delete!(state.worker_storage_pressure, proc.pid)
740747
delete!(state.worker_storage_capacity, proc.pid)
741748
delete!(state.worker_loadavg, proc.pid)

src/sch/util.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,8 @@ function estimate_task_costs(state, procs, task; sig=nothing)
561561
estimate_task_costs!(sorted_procs, costs, state, procs, task; sig)
562562
return sorted_procs, costs
563563
end
564+
const DEFAULT_TRANSFER_RATE = UInt64(1_000_000)
564565
@reuse_scope function estimate_task_costs!(sorted_procs, costs, state, procs, task; sig=nothing)
565-
tx_rate = state.transfer_rate[]
566566

567567
# Find all Chunks
568568
chunks = @reusable_vector :estimate_task_costs_chunks Union{Chunk,Nothing} nothing 32
@@ -595,7 +595,7 @@ end
595595
# TODO: Actually estimate/benchmark this
596596
task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms
597597

598-
# Compute final cost
598+
tx_rate = get(get(state.worker_transfer_rate, gproc.pid, Dict{Processor,UInt64}()), proc, DEFAULT_TRANSFER_RATE)
599599
costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost
600600
end
601601
chunks_cleanup()

test/scheduler.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ end
409409

410410
#pres1_1 = state.worker_time_pressure[1][tproc1_1]
411411
#pres2_1 = state.worker_time_pressure[first(workers())][tproc2_1]
412-
tx_rate = state.transfer_rate[]
412+
tx_rate = get(get(state.worker_transfer_rate, first(workers()), Dict{Dagger.Processor,UInt64}()), tproc2_1, Dagger.Sch.DEFAULT_TRANSFER_RATE)
413413
tx_xfer_cost = 1e6
414414
sig_unknown_cost = 1e9
415415

@@ -452,6 +452,27 @@ end
452452
@test costs[tproc2_1] (tx_size/tx_rate) + tx_xfer_cost + #=pres2_1 +=# sig_unknown_cost # All chunks are remote, and this signature is unknown
453453
end
454454
end
455+
456+
@testset "Per-Processor Transfer Rate" begin
457+
wid = first(workers())
458+
459+
state.worker_transfer_rate[1][tproc1_1] = UInt64(2_000_000)
460+
state.worker_transfer_rate[wid][tproc2_1] = UInt64(500_000)
461+
462+
args = [Dagger.tochunk(1), Dagger.tochunk(2)]
463+
tx_size = 2 * sizeof(Int)
464+
t = delayed(mynothing)(args...)
465+
Dagger.Sch.collect_task_inputs!(state, t)
466+
_, costs = Dagger.Sch.estimate_task_costs(state, procs, t)
467+
468+
if nprocs() > 1
469+
@test costs[tproc2_1] sig_unknown_cost + (tx_size / 500_000) + tx_xfer_cost
470+
end
471+
@test costs[tproc1_1] sig_unknown_cost
472+
473+
delete!(state.worker_transfer_rate[1], tproc1_1)
474+
delete!(state.worker_transfer_rate[wid], tproc2_1)
475+
end
455476
end
456477
end
457478

0 commit comments

Comments
 (0)