Skip to content

Commit b531ff0

Browse files
committed
[Heuristics] Add steal rejection tracking metric and unit test
1 parent dce89e9 commit b531ff0

2 files changed

Lines changed: 47 additions & 2 deletions

File tree

distributed/stealing.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(self, scheduler: Scheduler):
118118
self.metrics = {
119119
"request_count_total": defaultdict(int),
120120
"request_cost_total": defaultdict(int),
121+
"reject_count_margin_total": defaultdict(int),
121122
}
122123
self._request_counter = 0
123124
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
@@ -489,10 +490,17 @@ def balance(self) -> None:
489490

490491
# Require at least 50% ROI on the network transfer cost to prevent thrashing
491492
margin = comm_cost_thief * 0.5
492-
if (
493+
494+
would_steal_without_margin = (
495+
occ_thief + comm_cost_thief + compute
496+
<= occ_victim - (comm_cost_victim + compute) / 2
497+
)
498+
would_steal_with_margin = (
493499
occ_thief + comm_cost_thief + compute + margin
494500
<= occ_victim - (comm_cost_victim + compute) / 2
495-
):
501+
)
502+
503+
if would_steal_with_margin:
496504
self.move_task_request(ts, victim, thief)
497505
cost = compute + comm_cost_victim
498506
log.append(
@@ -523,6 +531,13 @@ def balance(self) -> None:
523531
# for removing ts from stealable. If we made sure to
524532
# properly clean up, we would not need this
525533
stealable.discard(ts)
534+
elif would_steal_without_margin:
535+
self.metrics["reject_count_margin_total"][level] += 1
536+
logger.debug(
537+
"Work-stealing margin heuristic rejected steal of task %s "
538+
"(thief=%s, victim=%s, level=%d, margin=%.4f)",
539+
ts.key, thief.address, victim.address, level, margin,
540+
)
526541
self.scheduler.check_idle_saturated(
527542
victim, occ=combined_occupancy(victim)
528543
)

distributed/tests/test_steal.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,6 +2010,36 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
20102010
await block_event.set()
20112011

20122012

2013+
@gen_cluster(
2014+
client=True,
2015+
nthreads=[("127.0.0.1", 1)] * 2,
2016+
config={"distributed.scheduler.work-stealing-interval": "100ms", **NO_AMM},
2017+
)
2018+
async def test_reject_count_margin_metric(c, s, a, b):
2019+
"""
2020+
Verify that the margin heuristic increments reject_count_margin_total
2021+
when it suppresses a steal that the pre-margin logic would have permitted.
2022+
"""
2023+
steal = s.extensions["stealing"]
2024+
await steal.stop()
2025+
2026+
# Scatter a large (50 MB) payload onto worker A to ensure a high network transfer cost
2027+
[x] = await c.scatter([b"0" * 50_000_000], workers=a.address)
2028+
2029+
# Create tasks on A to saturate it and trigger stealing evaluation
2030+
futures = [
2031+
c.submit(slowidentity, x, pure=False, delay=0.01, workers=a.address, allow_other_workers=True)
2032+
for _ in range(10)
2033+
]
2034+
2035+
while len(a.state.tasks) < 10:
2036+
await asyncio.sleep(0.01)
2037+
2038+
# Balance will evaluate the cost. High comm_cost, low compute.
2039+
# Without the margin, it would steal. With the 50% ROI margin, it should reject.
2040+
steal.balance()
2041+
2042+
assert sum(steal.metrics["reject_count_margin_total"].values()) >= 1
20132043
@gen_cluster(
20142044
nthreads=[("", 1)],
20152045
client=True,

0 commit comments

Comments
 (0)