Skip to content

Commit f121bba

Browse files
authored
fix: speed up scheduler queue views (#728)
1 parent 864322b commit f121bba

2 files changed

Lines changed: 138 additions & 30 deletions

File tree

  • packages/data-designer-engine
    • src/data_designer/engine/dataset_builders/scheduling
    • tests/engine/dataset_builders/scheduling

packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def __init__(self) -> None:
4646
self._queued: dict[str, SchedulableTask] = {}
4747
self._task_groups: dict[str, TaskGroupKey] = {}
4848
self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {}
49+
self._queued_by_group: Counter[TaskGroupKey] = Counter()
50+
self._queued_resource_demand_by_group: dict[TaskGroupKey, Counter[SchedulerResourceKey]] = defaultdict(Counter)
51+
self._queued_peer_demand_by_resource: Counter[SchedulerResourceKey] = Counter()
4952
self._group_finish: dict[TaskGroupKey, float] = {}
5053
self._heap: list[tuple[float, int, TaskGroupKey]] = []
5154
self._active_heap_keys: set[TaskGroupKey] = set()
@@ -69,6 +72,7 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]:
6972
queue.append(item)
7073
self._queued[item.task_id] = item
7174
self._task_groups[item.task_id] = item.group.key
75+
self._increment_queue_accounting(item)
7276
self._activate_group(item.group.key)
7377
accepted.append(item.task_id)
7478
if accepted:
@@ -77,10 +81,8 @@ def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]:
7781

7882
def discard(self, task_id: str) -> None:
7983
"""Remove a queued task lazily if it is no longer dispatchable."""
80-
if task_id in self._queued:
84+
if self._remove_queued_item(task_id) is not None:
8185
self._sequence_version += 1
82-
self._queued.pop(task_id, None)
83-
self._task_groups.pop(task_id, None)
8486

8587
def discard_where(self, predicate: Callable[[SchedulableTask], bool]) -> None:
8688
"""Remove queued tasks matching a predicate."""
@@ -125,8 +127,7 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None:
125127
return None
126128

127129
queue.popleft()
128-
self._queued.pop(item.task_id, None)
129-
self._task_groups.pop(item.task_id, None)
130+
self._remove_queued_item(item.task_id)
130131
self._active_heap_keys.discard(key)
131132
self._active_heap_entries.pop(key, None)
132133
group = self._group_specs[key]
@@ -140,35 +141,28 @@ def commit(self, selection: QueueSelection) -> SchedulableTask | None:
140141
return item
141142

142143
def view(self) -> QueueView:
143-
queued_by_group: Counter[TaskGroupKey] = Counter()
144-
demand_by_group: dict[TaskGroupKey, dict[SchedulerResourceKey, int]] = defaultdict(lambda: defaultdict(int))
145144
first_by_group: dict[TaskGroupKey, Mapping[SchedulerResourceKey, int]] = {}
146145
first_tasks_by_group: dict[TaskGroupKey, SchedulableTask] = {}
147146
first_group_specs: dict[TaskGroupKey, TaskGroupSpec] = {}
148-
demand_by_resource: Counter[SchedulerResourceKey] = Counter()
149147

150-
for item in self._queued.values():
151-
key = item.group.key
152-
queued_by_group[key] += 1
153-
for resource, amount in item.resource_request.amounts.items():
154-
demand_by_group[key][resource] += amount
155-
demand_by_resource[resource] += amount
156-
157-
for key, queue in self._queues.items():
148+
for key in self._queued_by_group:
158149
first = self._first_valid_item(key)
159-
if first is not None:
160-
first_by_group[key] = dict(first.resource_request.amounts)
161-
first_tasks_by_group[key] = first
162-
first_group_specs[key] = first.group
150+
if first is None:
151+
continue
152+
first_by_group[key] = dict(first.resource_request.amounts)
153+
first_tasks_by_group[key] = first
154+
first_group_specs[key] = first.group
163155

164156
return QueueView(
165157
queued_total=len(self._queued),
166-
queued_by_group=dict(queued_by_group),
167-
queued_resource_demand_by_group={key: dict(value) for key, value in demand_by_group.items()},
158+
queued_by_group=dict(self._queued_by_group),
159+
queued_resource_demand_by_group={
160+
key: dict(value) for key, value in self._queued_resource_demand_by_group.items()
161+
},
168162
first_candidate_resources_by_group=first_by_group,
169163
first_candidate_tasks_by_group=first_tasks_by_group,
170164
first_candidate_group_specs_by_group=first_group_specs,
171-
queued_peer_demand_by_resource=dict(demand_by_resource),
165+
queued_peer_demand_by_resource=dict(self._queued_peer_demand_by_resource),
172166
)
173167

174168
def _activate_group(self, key: TaskGroupKey) -> None:
@@ -183,13 +177,11 @@ def _activate_group(self, key: TaskGroupKey) -> None:
183177
self._active_heap_entries[key] = (finish, self._sequence)
184178

185179
def _first_valid_item(self, key: TaskGroupKey) -> SchedulableTask | None:
180+
self._purge_queue_head(key)
186181
queue = self._queues.get(key)
187-
if queue is None:
182+
if not queue:
188183
return None
189-
for item in queue:
190-
if item.task_id in self._queued and self._task_groups.get(item.task_id) == key:
191-
return item
192-
return None
184+
return queue[0]
193185

194186
def _purge_queue_head(self, key: TaskGroupKey) -> None:
195187
queue = self._queues.get(key)
@@ -200,3 +192,37 @@ def _purge_queue_head(self, key: TaskGroupKey) -> None:
200192
if item.task_id in self._queued and self._task_groups.get(item.task_id) == key:
201193
break
202194
queue.popleft()
195+
196+
def _increment_queue_accounting(self, item: SchedulableTask) -> None:
197+
key = item.group.key
198+
self._queued_by_group[key] += 1
199+
for resource, amount in item.resource_request.amounts.items():
200+
self._queued_resource_demand_by_group[key][resource] += amount
201+
self._queued_peer_demand_by_resource[resource] += amount
202+
203+
def _remove_queued_item(self, task_id: str) -> SchedulableTask | None:
204+
item = self._queued.pop(task_id, None)
205+
key = self._task_groups.pop(task_id, None)
206+
if item is None or key is None:
207+
return item
208+
self._decrement_queue_accounting(item, key)
209+
return item
210+
211+
def _decrement_queue_accounting(self, item: SchedulableTask, key: TaskGroupKey) -> None:
212+
self._queued_by_group[key] -= 1
213+
if self._queued_by_group[key] <= 0:
214+
del self._queued_by_group[key]
215+
216+
group_demand = self._queued_resource_demand_by_group.get(key)
217+
if group_demand is not None:
218+
for resource, amount in item.resource_request.amounts.items():
219+
group_demand[resource] -= amount
220+
if group_demand[resource] <= 0:
221+
del group_demand[resource]
222+
if not group_demand:
223+
del self._queued_resource_demand_by_group[key]
224+
225+
for resource, amount in item.resource_request.amounts.items():
226+
self._queued_peer_demand_by_resource[resource] -= amount
227+
if self._queued_peer_demand_by_resource[resource] <= 0:
228+
del self._queued_peer_demand_by_resource[resource]

packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
from __future__ import annotations
55

66
from collections import Counter
7+
from collections.abc import ItemsView
78

89
from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView
910
from data_designer.engine.dataset_builders.scheduling.resources import (
1011
SchedulableTask,
12+
SchedulerResourceKey,
1113
SchedulerResourceRequest,
1214
TaskGroupKey,
1315
TaskGroupSpec,
@@ -16,6 +18,15 @@
1618
from data_designer.engine.dataset_builders.scheduling.task_model import Task
1719

1820

21+
class _FailIfScannedAmounts(dict[SchedulerResourceKey, int]):
22+
locked: bool = False
23+
24+
def items(self) -> ItemsView[SchedulerResourceKey, int]:
25+
if self.locked:
26+
raise AssertionError("QueueView should use incremental accounting for non-candidate tasks.")
27+
return super().items()
28+
29+
1930
def _task(column: str, row_index: int) -> Task:
2031
return Task(column=column, row_group=0, row_index=row_index, task_type="cell")
2132

@@ -118,14 +129,25 @@ def test_select_next_uses_scheduler_eligibility_callback() -> None:
118129

119130
def test_enqueue_is_idempotent_by_task_id() -> None:
120131
queue = FairTaskQueue()
121-
item = _item("a", 0)
132+
group = _group("a")
133+
task = _task("a", 0)
134+
item = SchedulableTask(
135+
task_id=stable_task_id(task),
136+
payload=task,
137+
group=group,
138+
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}),
139+
)
122140

123141
first = queue.enqueue([item])
124142
second = queue.enqueue([item])
143+
view = queue.view()
125144

126145
assert first == (item.task_id,)
127146
assert second == ()
128-
assert queue.view().queued_total == 1
147+
assert view.queued_total == 1
148+
assert view.queued_by_group == {group.key: 1}
149+
assert view.queued_resource_demand_by_group == {group.key: {"submission": 1, "llm_wait": 2}}
150+
assert view.queued_peer_demand_by_resource == {"submission": 1, "llm_wait": 2}
129151

130152

131153
def test_discard_where_removes_matching_tasks() -> None:
@@ -157,3 +179,63 @@ def test_queue_view_exposes_group_and_resource_demand() -> None:
157179
assert view.queued_by_group[group.key] == 1
158180
assert view.queued_resource_demand_by_group[group.key]["llm_wait"] == 1
159181
assert view.first_candidate_resources_by_group[group.key]["submission"] == 1
182+
183+
184+
def test_queue_view_updates_incremental_accounting_after_removals() -> None:
185+
queue = FairTaskQueue()
186+
first_group = _group("a")
187+
second_group = _group("b")
188+
first = SchedulableTask(
189+
task_id=stable_task_id(_task("a", 0)),
190+
payload=_task("a", 0),
191+
group=first_group,
192+
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 2}),
193+
)
194+
second = SchedulableTask(
195+
task_id=stable_task_id(_task("b", 0)),
196+
payload=_task("b", 0),
197+
group=second_group,
198+
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 3}),
199+
)
200+
third = SchedulableTask(
201+
task_id=stable_task_id(_task("b", 1)),
202+
payload=_task("b", 1),
203+
group=second_group,
204+
resource_request=SchedulerResourceRequest({"submission": 1, "local": 1}),
205+
)
206+
queue.enqueue([first, second, third])
207+
208+
queue.discard(first.task_id)
209+
committed = _select_and_commit(queue)
210+
211+
assert committed == second
212+
view = queue.view()
213+
assert view.queued_total == 1
214+
assert first_group.key not in view.queued_by_group
215+
assert view.queued_by_group == {second_group.key: 1}
216+
assert view.queued_resource_demand_by_group == {second_group.key: {"submission": 1, "local": 1}}
217+
assert view.queued_peer_demand_by_resource == {"submission": 1, "local": 1}
218+
219+
220+
def test_queue_view_uses_incremental_accounting_for_non_candidate_tasks() -> None:
221+
queue = FairTaskQueue()
222+
group = _group("a")
223+
first = _item("a", 0, group)
224+
amounts = _FailIfScannedAmounts({"submission": 1})
225+
task = _task("a", 1)
226+
second = SchedulableTask(
227+
task_id=stable_task_id(task),
228+
payload=task,
229+
group=group,
230+
resource_request=SchedulerResourceRequest(amounts),
231+
)
232+
queue.enqueue([first, second])
233+
amounts.locked = True
234+
235+
view = queue.view()
236+
237+
assert view.queued_total == 2
238+
assert view.queued_by_group == {group.key: 2}
239+
assert view.queued_resource_demand_by_group == {group.key: {"submission": 2}}
240+
assert view.first_candidate_resources_by_group == {group.key: {"submission": 1}}
241+
assert view.queued_peer_demand_by_resource == {"submission": 2}

0 commit comments

Comments
 (0)