Skip to content

Commit 5981281

Browse files
committed
feat: make async in-flight task cap configurable
Signed-off-by: Eric W. Tramel <eric.tramel@gmail.com>
1 parent 000fc09 commit 5981281

14 files changed

Lines changed: 103 additions & 58 deletions

File tree

packages/data-designer-config/src/data_designer/config/run_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class RunConfig(ConfigBase):
133133
buffer_size: Number of records to process in each batch during dataset generation.
134134
A batch is processed end-to-end (column generation, post-batch processors, and writing the batch
135135
to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000.
136+
max_in_flight_tasks: Maximum number of async scheduler tasks that may hold task
137+
leases at once. Tasks may be executing, awaiting I/O, or waiting on model
138+
request admission. Model API request concurrency is controlled separately by
139+
``max_parallel_requests``. Must be >= 1. Default is 1024.
136140
non_inference_max_parallel_workers: Maximum number of worker threads used for non-inference
137141
cell-by-cell generators. Must be >= 1. Default is 4.
138142
max_conversation_restarts: Maximum number of full conversation restarts permitted when
@@ -165,6 +169,14 @@ class RunConfig(ConfigBase):
165169
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
166170
shutdown_error_window: int = Field(default=10, ge=1)
167171
buffer_size: int = Field(default=1000, gt=0)
172+
max_in_flight_tasks: int = Field(
173+
default=1024,
174+
ge=1,
175+
description=(
176+
"Maximum number of async scheduler tasks that may hold task leases at once. "
177+
"Model API request concurrency is controlled separately by max_parallel_requests."
178+
),
179+
)
168180
non_inference_max_parallel_workers: int = Field(default=4, ge=1)
169181
max_conversation_restarts: int = Field(default=5, ge=0)
170182
max_conversation_correction_steps: int = Field(default=0, ge=0)

packages/data-designer-config/tests/config/test_run_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,21 @@ def test_run_config_accepts_native_renderer() -> None:
2424
assert JinjaRenderingEngine(run_config.jinja_rendering_engine) == JinjaRenderingEngine.NATIVE
2525

2626

27+
def test_run_config_defaults_max_in_flight_tasks_to_1024() -> None:
28+
assert RunConfig().max_in_flight_tasks == 1024
29+
30+
31+
def test_run_config_accepts_custom_max_in_flight_tasks() -> None:
32+
run_config = RunConfig(max_in_flight_tasks=2048)
33+
34+
assert run_config.max_in_flight_tasks == 2048
35+
36+
37+
def test_run_config_rejects_invalid_max_in_flight_tasks() -> None:
38+
with pytest.raises(ValidationError, match="max_in_flight_tasks"):
39+
RunConfig(max_in_flight_tasks=0)
40+
41+
2742
def test_run_config_throttle_shim_rejects_unknown_legacy_fields() -> None:
2843
with pytest.raises(ValidationError, match="max_concurrent_requests"):
2944
RunConfig(throttle={"max_concurrent_requests": 1})

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
stable_task_id,
3939
)
4040
from data_designer.engine.dataset_builders.scheduling.task_admission import (
41+
DEFAULT_IN_FLIGHT_TASK_CAPACITY,
4142
TaskAdmissionConfig,
4243
TaskAdmissionController,
4344
TaskAdmissionDenied,
@@ -76,7 +77,7 @@
7677

7778
logger = logging.getLogger(__name__)
7879

79-
DEFAULT_TASK_POOL_SIZE: int = 256
80+
DEFAULT_TASK_POOL_SIZE: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY
8081
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2
8182
MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2
8283

@@ -144,7 +145,7 @@ def __init__(
144145
buffer_manager: RowGroupBufferManager | None = None,
145146
*,
146147
max_concurrent_row_groups: int = 3,
147-
max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
148+
max_in_flight_tasks: int = DEFAULT_TASK_POOL_SIZE,
148149
max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE,
149150
task_admission_config: TaskAdmissionConfig | None = None,
150151
salvage_max_rounds: int = 2,
@@ -183,8 +184,8 @@ def __init__(
183184
model_group_limit_cap=max_model_task_admission,
184185
)
185186
admission_config = task_admission_config or TaskAdmissionConfig(
186-
submission_capacity=max_submitted_tasks,
187-
resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks},
187+
submission_capacity=max_in_flight_tasks,
188+
resource_limits={"llm_wait": max_model_task_admission, "local": max_in_flight_tasks},
188189
)
189190
self._task_admission = TaskAdmissionController(admission_config)
190191
self._task_admission_config = admission_config
@@ -277,7 +278,7 @@ def __init__(
277278
# Pre-compute row-group sizes for O(1) lookup
278279
self._rg_size_map: dict[int, int] = dict(row_groups)
279280
self._max_concurrent_row_groups = max_concurrent_row_groups
280-
self._max_submitted_tasks = max_submitted_tasks
281+
self._max_in_flight_tasks = max_in_flight_tasks
281282
self._max_model_task_admission = max_model_task_admission
282283
self._num_records = num_records
283284
self._buffer_size = buffer_size
@@ -910,7 +911,7 @@ def _adaptive_row_group_block_reason(self) -> str | None:
910911
if not self._row_group_row_guard_allows(next_size):
911912
return "max_admitted_rows"
912913
queue_view = self._fair_queue.view()
913-
queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2)
914+
queue_guard = max(self._max_in_flight_tasks * 4, self._max_model_task_admission * 2)
914915
if queue_view.queued_total >= queue_guard:
915916
return "queued_task_guardrail"
916917
task_view = self._task_admission.view()
@@ -1907,7 +1908,7 @@ def capacity_plan(self) -> AsyncCapacityPlan:
19071908
max_admitted_rows=self._adaptive_max_admitted_rows,
19081909
blocked_reasons=dict(self._row_group_admission_blocked_reasons),
19091910
),
1910-
submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"),
1911+
submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"),
19111912
task_resource_limits=CapacityValue(
19121913
value=dict(self._task_admission_config.resource_limits),
19131914
source="engine_internal_config",

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,14 +1024,21 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
10241024
# at the model-call boundary.
10251025
aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests()
10261026

1027+
max_in_flight_tasks = self._resource_provider.run_config.max_in_flight_tasks
1028+
max_model_task_admission = max(
1029+
DEFAULT_TASK_POOL_SIZE,
1030+
max_in_flight_tasks,
1031+
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate,
1032+
)
1033+
10271034
scheduler = AsyncTaskScheduler(
10281035
generators=gen_map,
10291036
graph=graph,
10301037
tracker=tracker,
10311038
row_groups=row_groups,
10321039
buffer_manager=buffer_manager,
1033-
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
1034-
max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate),
1040+
max_in_flight_tasks=max_in_flight_tasks,
1041+
max_model_task_admission=max_model_task_admission,
10351042
on_finalize_row_group=on_finalize_row_group,
10361043
on_seeds_complete=(
10371044
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None

packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@
3636
"unknown_lease",
3737
]
3838
RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192
39+
DEFAULT_IN_FLIGHT_TASK_CAPACITY = 1024
3940

4041

4142
@dataclass(frozen=True)
4243
class TaskAdmissionConfig:
4344
"""Engine-internal scheduler task-stage admission configuration."""
4445

45-
submission_capacity: int = 256
46+
submission_capacity: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY
4647
resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict)
4748
bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None
4849

packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def __init__(self, **kwargs: object) -> None:
205205
model_registry.request_admission = request_admission
206206
provider = SimpleNamespace(
207207
model_registry=model_registry,
208-
run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False),
208+
run_config=SimpleNamespace(max_in_flight_tasks=1536, progress_interval=5.0, progress_bar=False),
209209
)
210210
processor_runner = MagicMock()
211211
processor_runner.has_processors_for.return_value = False
@@ -222,6 +222,8 @@ def __init__(self, **kwargs: object) -> None:
222222

223223
assert captured_kwargs["request_pressure_provider"] is request_admission
224224
assert captured_kwargs["request_pressure_advisory"] is True
225+
assert captured_kwargs["max_in_flight_tasks"] == 1536
226+
assert captured_kwargs["max_model_task_admission"] == 1536
225227

226228

227229
# -- Test that existing sync path is unaffected --------------------------------

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,8 @@ async def test_scheduler_stateful_generator_serializes() -> None:
856856

857857

858858
@pytest.mark.asyncio(loop_scope="session")
859-
async def test_scheduler_bounded_submission() -> None:
860-
"""Submitted task count respects max_submitted_tasks."""
859+
async def test_scheduler_bounded_in_flight_tasks() -> None:
860+
"""In-flight task count respects max_in_flight_tasks."""
861861
provider = _mock_provider()
862862

863863
# Use a pipeline with many cells and low submission limit
@@ -883,7 +883,7 @@ async def test_scheduler_bounded_submission() -> None:
883883
graph=graph,
884884
tracker=tracker,
885885
row_groups=row_groups,
886-
max_submitted_tasks=2,
886+
max_in_flight_tasks=2,
887887
)
888888
await scheduler.run()
889889

@@ -1821,22 +1821,22 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None:
18211821
row_groups = [(0, 3)]
18221822
tracker = CompletionTracker.with_graph(graph, row_groups)
18231823

1824-
max_submitted = 2
1824+
max_in_flight = 2
18251825
max_llm_wait = 2
18261826
scheduler = AsyncTaskScheduler(
18271827
generators=generators,
18281828
graph=graph,
18291829
tracker=tracker,
18301830
row_groups=row_groups,
1831-
max_submitted_tasks=max_submitted,
1831+
max_in_flight_tasks=max_in_flight,
18321832
max_model_task_admission=max_llm_wait,
18331833
)
18341834
await scheduler.run()
18351835

18361836
assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"])
18371837

18381838
snapshot = scheduler.task_admission_snapshot()
1839-
assert snapshot.resources_available["submission"] == max_submitted
1839+
assert snapshot.resources_available["submission"] == max_in_flight
18401840
assert snapshot.resources_available["llm_wait"] == max_llm_wait
18411841

18421842

@@ -1867,7 +1867,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18671867
graph=graph,
18681868
tracker=tracker,
18691869
row_groups=row_groups,
1870-
max_submitted_tasks=2,
1870+
max_in_flight_tasks=2,
18711871
max_model_task_admission=max_llm_wait,
18721872
)
18731873
await scheduler.run()
@@ -1880,7 +1880,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18801880

18811881
@pytest.mark.asyncio(loop_scope="session")
18821882
async def test_scheduler_deadlock_regression() -> None:
1883-
"""max_submitted_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
1883+
"""max_in_flight_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
18841884
provider = _mock_provider()
18851885
configs = [
18861886
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
@@ -1904,7 +1904,7 @@ async def test_scheduler_deadlock_regression() -> None:
19041904
graph=graph,
19051905
tracker=tracker,
19061906
row_groups=row_groups,
1907-
max_submitted_tasks=1,
1907+
max_in_flight_tasks=1,
19081908
max_model_task_admission=1,
19091909
)
19101910

@@ -2379,23 +2379,23 @@ async def test_scheduler_llm_bound_429_retried_in_salvage() -> None:
23792379
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
23802380
buffer_mgr = RowGroupBufferManager(storage)
23812381

2382-
max_submitted = 4
2382+
max_in_flight = 4
23832383
max_llm_wait = 2
23842384
scheduler = AsyncTaskScheduler(
23852385
generators=generators,
23862386
graph=graph,
23872387
tracker=tracker,
23882388
row_groups=row_groups,
23892389
buffer_manager=buffer_mgr,
2390-
max_submitted_tasks=max_submitted,
2390+
max_in_flight_tasks=max_in_flight,
23912391
max_model_task_admission=max_llm_wait,
23922392
)
23932393
await scheduler.run()
23942394

23952395
assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"])
23962396

23972397
snapshot = scheduler.task_admission_snapshot()
2398-
assert snapshot.resources_available["submission"] == max_submitted
2398+
assert snapshot.resources_available["submission"] == max_in_flight
23992399
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24002400

24012401

@@ -2441,15 +2441,15 @@ async def agenerate(self, data: dict) -> dict:
24412441
row_groups = [(0, 2)]
24422442
tracker = CompletionTracker.with_graph(graph, row_groups)
24432443

2444-
max_submitted = 4
2444+
max_in_flight = 4
24452445
max_llm_wait = 2
24462446
sink = InMemoryAdmissionEventSink()
24472447
scheduler = AsyncTaskScheduler(
24482448
generators=generators,
24492449
graph=graph,
24502450
tracker=tracker,
24512451
row_groups=row_groups,
2452-
max_submitted_tasks=max_submitted,
2452+
max_in_flight_tasks=max_in_flight,
24532453
max_model_task_admission=max_llm_wait,
24542454
scheduler_event_sink=sink,
24552455
)
@@ -2462,7 +2462,7 @@ async def agenerate(self, data: dict) -> dict:
24622462
await run_task
24632463

24642464
snapshot = scheduler.task_admission_snapshot()
2465-
assert snapshot.resources_available["submission"] == max_submitted
2465+
assert snapshot.resources_available["submission"] == max_in_flight
24662466
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24672467
assert "cancelled" in [event.event_kind for event in sink.scheduler_events]
24682468
assert all(event.snapshot is not None for event in sink.scheduler_events)
@@ -2684,7 +2684,7 @@ async def test_scheduler_fair_admission_across_ready_columns() -> None:
26842684
graph=graph,
26852685
tracker=tracker,
26862686
row_groups=row_groups,
2687-
max_submitted_tasks=4,
2687+
max_in_flight_tasks=4,
26882688
trace=True,
26892689
)
26902690

@@ -2758,7 +2758,7 @@ async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
27582758
graph=graph,
27592759
tracker=tracker,
27602760
row_groups=row_groups,
2761-
max_submitted_tasks=8,
2761+
max_in_flight_tasks=8,
27622762
max_concurrent_row_groups=2,
27632763
trace=True,
27642764
)
@@ -2806,7 +2806,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None:
28062806
graph=graph,
28072807
tracker=tracker,
28082808
row_groups=row_groups,
2809-
max_submitted_tasks=4,
2809+
max_in_flight_tasks=4,
28102810
max_model_task_admission=4,
28112811
trace=True,
28122812
)
@@ -2877,7 +2877,7 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None:
28772877
tracker=tracker,
28782878
row_groups=row_groups,
28792879
buffer_manager=buffer_manager,
2880-
max_submitted_tasks=4,
2880+
max_in_flight_tasks=4,
28812881
trace=True,
28822882
)
28832883
await asyncio.wait_for(scheduler.run(), timeout=10.0)
@@ -2925,7 +2925,7 @@ async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None:
29252925
tracker=tracker,
29262926
row_groups=row_groups,
29272927
max_concurrent_row_groups=2,
2928-
max_submitted_tasks=2,
2928+
max_in_flight_tasks=2,
29292929
trace=True,
29302930
num_records=12,
29312931
buffer_size=3,
@@ -3023,7 +3023,7 @@ async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None:
30233023
tracker=tracker,
30243024
row_groups=row_groups,
30253025
max_concurrent_row_groups=1,
3026-
max_submitted_tasks=2,
3026+
max_in_flight_tasks=2,
30273027
max_model_task_admission=1,
30283028
scheduler_event_sink=sink,
30293029
num_records=2,
@@ -3089,7 +3089,7 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon
30893089
tracker=tracker,
30903090
row_groups=row_groups,
30913091
max_concurrent_row_groups=4,
3092-
max_submitted_tasks=4,
3092+
max_in_flight_tasks=4,
30933093
max_model_task_admission=4,
30943094
adaptive_row_group_admission=True,
30953095
adaptive_row_group_initial_target=1,

packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def test_dataset_builder_validate_column_configs(
421421

422422
def test_run_config_default_non_inference_max_parallel_workers() -> None:
423423
run_config = RunConfig()
424+
assert run_config.max_in_flight_tasks == 1024
424425
assert run_config.non_inference_max_parallel_workers == 4
425426

426427

packages/data-designer/src/data_designer/interface/data_designer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,9 @@ def set_run_config(self, run_config: RunConfig) -> None:
605605
606606
Args:
607607
run_config: A RunConfig instance containing runtime settings such as
608-
early shutdown behavior, batch sizing via `buffer_size`, and non-inference worker
609-
concurrency via `non_inference_max_parallel_workers`.
608+
early shutdown behavior, batch sizing via `buffer_size`, async task lease
609+
capacity via `max_in_flight_tasks`, and non-inference worker concurrency via
610+
`non_inference_max_parallel_workers`.
610611
611612
Notes:
612613
When `disable_early_shutdown=True`, DataDesigner will never terminate generation early

0 commit comments

Comments
 (0)