Skip to content

Commit eccce53

Browse files
committed
feat: make async in-flight task cap configurable
Signed-off-by: Eric W. Tramel <eric.tramel@gmail.com>
1 parent b6de38d commit eccce53

10 files changed

Lines changed: 91 additions & 47 deletions

File tree

packages/data-designer-config/src/data_designer/config/run_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class RunConfig(ConfigBase):
133133
buffer_size: Number of records to process in each batch during dataset generation.
134134
A batch is processed end-to-end (column generation, post-batch processors, and writing the batch
135135
to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000.
136+
max_in_flight_tasks: Maximum number of async scheduler tasks that may hold task
137+
leases at once. Tasks may be executing, awaiting I/O, or waiting on model
138+
request admission. Model API request concurrency is controlled separately by
139+
``max_parallel_requests``. Must be >= 1. Default is 1024.
136140
non_inference_max_parallel_workers: Maximum number of worker threads used for non-inference
137141
cell-by-cell generators. Must be >= 1. Default is 4.
138142
max_conversation_restarts: Maximum number of full conversation restarts permitted when
@@ -168,6 +172,14 @@ class RunConfig(ConfigBase):
168172
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
169173
shutdown_error_window: int = Field(default=10, ge=1)
170174
buffer_size: int = Field(default=1000, gt=0)
175+
max_in_flight_tasks: int = Field(
176+
default=1024,
177+
ge=1,
178+
description=(
179+
"Maximum number of async scheduler tasks that may hold task leases at once. "
180+
"Model API request concurrency is controlled separately by max_parallel_requests."
181+
),
182+
)
171183
non_inference_max_parallel_workers: int = Field(default=4, ge=1)
172184
max_conversation_restarts: int = Field(default=5, ge=0)
173185
max_conversation_correction_steps: int = Field(default=0, ge=0)

packages/data-designer-config/tests/config/test_run_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ def test_run_config_accepts_disabled_dropped_column_preservation() -> None:
3333
assert run_config.preserve_dropped_columns is False
3434

3535

36+
def test_run_config_defaults_max_in_flight_tasks_to_1024() -> None:
37+
assert RunConfig().max_in_flight_tasks == 1024
38+
39+
40+
def test_run_config_accepts_custom_max_in_flight_tasks() -> None:
41+
run_config = RunConfig(max_in_flight_tasks=2048)
42+
43+
assert run_config.max_in_flight_tasks == 2048
44+
45+
46+
def test_run_config_rejects_invalid_max_in_flight_tasks() -> None:
47+
with pytest.raises(ValidationError, match="max_in_flight_tasks"):
48+
RunConfig(max_in_flight_tasks=0)
49+
50+
3651
def test_run_config_throttle_shim_rejects_unknown_legacy_fields() -> None:
3752
with pytest.raises(ValidationError, match="max_concurrent_requests"):
3853
RunConfig(throttle={"max_concurrent_requests": 1})

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
stable_task_id,
3939
)
4040
from data_designer.engine.dataset_builders.scheduling.task_admission import (
41+
DEFAULT_IN_FLIGHT_TASK_CAPACITY,
4142
TaskAdmissionConfig,
4243
TaskAdmissionController,
4344
TaskAdmissionDenied,
@@ -76,8 +77,6 @@
7677

7778
logger = logging.getLogger(__name__)
7879

79-
DEFAULT_TASK_POOL_SIZE: int = 256
80-
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2
8180
MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2
8281

8382
# Degraded-provider WARN: emit at most one warning per interval when the
@@ -144,8 +143,8 @@ def __init__(
144143
buffer_manager: RowGroupBufferManager | None = None,
145144
*,
146145
max_concurrent_row_groups: int = 3,
147-
max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
148-
max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE,
146+
max_in_flight_tasks: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
147+
max_model_task_admission: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
149148
task_admission_config: TaskAdmissionConfig | None = None,
150149
salvage_max_rounds: int = 2,
151150
on_finalize_row_group: Callable[[int], None] | None = None,
@@ -183,8 +182,8 @@ def __init__(
183182
model_group_limit_cap=max_model_task_admission,
184183
)
185184
admission_config = task_admission_config or TaskAdmissionConfig(
186-
submission_capacity=max_submitted_tasks,
187-
resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks},
185+
submission_capacity=max_in_flight_tasks,
186+
resource_limits={"llm_wait": max_model_task_admission},
188187
)
189188
self._task_admission = TaskAdmissionController(admission_config)
190189
self._task_admission_config = admission_config
@@ -277,7 +276,7 @@ def __init__(
277276
# Pre-compute row-group sizes for O(1) lookup
278277
self._rg_size_map: dict[int, int] = dict(row_groups)
279278
self._max_concurrent_row_groups = max_concurrent_row_groups
280-
self._max_submitted_tasks = max_submitted_tasks
279+
self._max_in_flight_tasks = max_in_flight_tasks
281280
self._max_model_task_admission = max_model_task_admission
282281
self._num_records = num_records
283282
self._buffer_size = buffer_size
@@ -910,7 +909,7 @@ def _adaptive_row_group_block_reason(self) -> str | None:
910909
if not self._row_group_row_guard_allows(next_size):
911910
return "max_admitted_rows"
912911
queue_view = self._fair_queue.view()
913-
queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2)
912+
queue_guard = self._max_in_flight_tasks * 4
914913
if queue_view.queued_total >= queue_guard:
915914
return "queued_task_guardrail"
916915
task_view = self._task_admission.view()
@@ -1907,7 +1906,7 @@ def capacity_plan(self) -> AsyncCapacityPlan:
19071906
max_admitted_rows=self._adaptive_max_admitted_rows,
19081907
blocked_reasons=dict(self._row_group_admission_blocked_reasons),
19091908
),
1910-
submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"),
1909+
submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"),
19111910
task_resource_limits=CapacityValue(
19121911
value=dict(self._task_admission_config.resource_limits),
19131912
source="engine_internal_config",

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@
8484
import asyncio
8585

8686
from data_designer.engine.dataset_builders.async_scheduler import (
87-
DEFAULT_TASK_POOL_SIZE,
88-
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER,
8987
AsyncTaskScheduler,
9088
)
9189
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
@@ -1055,19 +1053,17 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
10551053
df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id, strict_row_count=True)
10561054
buffer_manager.replace_dataframe(rg_id, df)
10571055

1058-
# Coarse upper bound used only for scheduler task-stage model admission.
1059-
# Concrete provider/model request capacity is enforced by request admission
1060-
# at the model-call boundary.
1061-
aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests()
1056+
max_in_flight_tasks = self._resource_provider.run_config.max_in_flight_tasks
1057+
max_model_task_admission = max_in_flight_tasks
10621058

10631059
scheduler = AsyncTaskScheduler(
10641060
generators=gen_map,
10651061
graph=graph,
10661062
tracker=tracker,
10671063
row_groups=row_groups,
10681064
buffer_manager=buffer_manager,
1069-
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
1070-
max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate),
1065+
max_in_flight_tasks=max_in_flight_tasks,
1066+
max_model_task_admission=max_model_task_admission,
10711067
on_finalize_row_group=on_finalize_row_group,
10721068
on_seeds_complete=(
10731069
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None

packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@
3636
"unknown_lease",
3737
]
3838
RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192
39+
DEFAULT_IN_FLIGHT_TASK_CAPACITY = 1024
3940

4041

4142
@dataclass(frozen=True)
4243
class TaskAdmissionConfig:
4344
"""Engine-internal scheduler task-stage admission configuration."""
4445

45-
submission_capacity: int = 256
46+
submission_capacity: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY
4647
resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict)
4748
bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None
4849

packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,13 @@ def __init__(self, **kwargs: object) -> None:
201201
monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler)
202202
request_admission = object()
203203
model_registry = MagicMock()
204-
model_registry.get_aggregate_max_parallel_requests.return_value = 2
204+
model_registry.get_aggregate_max_parallel_requests.side_effect = AssertionError(
205+
"model task admission should follow max_in_flight_tasks directly"
206+
)
205207
model_registry.request_admission = request_admission
206208
provider = SimpleNamespace(
207209
model_registry=model_registry,
208-
run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False),
210+
run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False),
209211
)
210212
processor_runner = MagicMock()
211213
processor_runner.has_processors_for.return_value = False
@@ -222,6 +224,8 @@ def __init__(self, **kwargs: object) -> None:
222224

223225
assert captured_kwargs["request_pressure_provider"] is request_admission
224226
assert captured_kwargs["request_pressure_advisory"] is True
227+
assert captured_kwargs["max_in_flight_tasks"] == 64
228+
assert captured_kwargs["max_model_task_admission"] == 64
225229

226230

227231
# -- Test that existing sync path is unaffected --------------------------------

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,8 @@ async def test_scheduler_stateful_generator_serializes() -> None:
856856

857857

858858
@pytest.mark.asyncio(loop_scope="session")
859-
async def test_scheduler_bounded_submission() -> None:
860-
"""Submitted task count respects max_submitted_tasks."""
859+
async def test_scheduler_bounded_in_flight_tasks() -> None:
860+
"""In-flight task count respects max_in_flight_tasks."""
861861
provider = _mock_provider()
862862

863863
# Use a pipeline with many cells and low submission limit
@@ -883,7 +883,7 @@ async def test_scheduler_bounded_submission() -> None:
883883
graph=graph,
884884
tracker=tracker,
885885
row_groups=row_groups,
886-
max_submitted_tasks=2,
886+
max_in_flight_tasks=2,
887887
)
888888
await scheduler.run()
889889

@@ -1821,22 +1821,22 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None:
18211821
row_groups = [(0, 3)]
18221822
tracker = CompletionTracker.with_graph(graph, row_groups)
18231823

1824-
max_submitted = 2
1824+
max_in_flight = 2
18251825
max_llm_wait = 2
18261826
scheduler = AsyncTaskScheduler(
18271827
generators=generators,
18281828
graph=graph,
18291829
tracker=tracker,
18301830
row_groups=row_groups,
1831-
max_submitted_tasks=max_submitted,
1831+
max_in_flight_tasks=max_in_flight,
18321832
max_model_task_admission=max_llm_wait,
18331833
)
18341834
await scheduler.run()
18351835

18361836
assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"])
18371837

18381838
snapshot = scheduler.task_admission_snapshot()
1839-
assert snapshot.resources_available["submission"] == max_submitted
1839+
assert snapshot.resources_available["submission"] == max_in_flight
18401840
assert snapshot.resources_available["llm_wait"] == max_llm_wait
18411841

18421842

@@ -1867,7 +1867,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18671867
graph=graph,
18681868
tracker=tracker,
18691869
row_groups=row_groups,
1870-
max_submitted_tasks=2,
1870+
max_in_flight_tasks=2,
18711871
max_model_task_admission=max_llm_wait,
18721872
)
18731873
await scheduler.run()
@@ -1880,7 +1880,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18801880

18811881
@pytest.mark.asyncio(loop_scope="session")
18821882
async def test_scheduler_deadlock_regression() -> None:
1883-
"""max_submitted_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
1883+
"""max_in_flight_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
18841884
provider = _mock_provider()
18851885
configs = [
18861886
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
@@ -1904,7 +1904,7 @@ async def test_scheduler_deadlock_regression() -> None:
19041904
graph=graph,
19051905
tracker=tracker,
19061906
row_groups=row_groups,
1907-
max_submitted_tasks=1,
1907+
max_in_flight_tasks=1,
19081908
max_model_task_admission=1,
19091909
)
19101910

@@ -2379,23 +2379,23 @@ async def test_scheduler_llm_bound_429_retried_in_salvage() -> None:
23792379
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
23802380
buffer_mgr = RowGroupBufferManager(storage)
23812381

2382-
max_submitted = 4
2382+
max_in_flight = 4
23832383
max_llm_wait = 2
23842384
scheduler = AsyncTaskScheduler(
23852385
generators=generators,
23862386
graph=graph,
23872387
tracker=tracker,
23882388
row_groups=row_groups,
23892389
buffer_manager=buffer_mgr,
2390-
max_submitted_tasks=max_submitted,
2390+
max_in_flight_tasks=max_in_flight,
23912391
max_model_task_admission=max_llm_wait,
23922392
)
23932393
await scheduler.run()
23942394

23952395
assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"])
23962396

23972397
snapshot = scheduler.task_admission_snapshot()
2398-
assert snapshot.resources_available["submission"] == max_submitted
2398+
assert snapshot.resources_available["submission"] == max_in_flight
23992399
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24002400

24012401

@@ -2441,15 +2441,15 @@ async def agenerate(self, data: dict) -> dict:
24412441
row_groups = [(0, 2)]
24422442
tracker = CompletionTracker.with_graph(graph, row_groups)
24432443

2444-
max_submitted = 4
2444+
max_in_flight = 4
24452445
max_llm_wait = 2
24462446
sink = InMemoryAdmissionEventSink()
24472447
scheduler = AsyncTaskScheduler(
24482448
generators=generators,
24492449
graph=graph,
24502450
tracker=tracker,
24512451
row_groups=row_groups,
2452-
max_submitted_tasks=max_submitted,
2452+
max_in_flight_tasks=max_in_flight,
24532453
max_model_task_admission=max_llm_wait,
24542454
scheduler_event_sink=sink,
24552455
)
@@ -2462,7 +2462,7 @@ async def agenerate(self, data: dict) -> dict:
24622462
await run_task
24632463

24642464
snapshot = scheduler.task_admission_snapshot()
2465-
assert snapshot.resources_available["submission"] == max_submitted
2465+
assert snapshot.resources_available["submission"] == max_in_flight
24662466
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24672467
assert "cancelled" in [event.event_kind for event in sink.scheduler_events]
24682468
assert all(event.snapshot is not None for event in sink.scheduler_events)
@@ -2684,7 +2684,7 @@ async def test_scheduler_fair_admission_across_ready_columns() -> None:
26842684
graph=graph,
26852685
tracker=tracker,
26862686
row_groups=row_groups,
2687-
max_submitted_tasks=4,
2687+
max_in_flight_tasks=4,
26882688
trace=True,
26892689
)
26902690

@@ -2758,7 +2758,7 @@ async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
27582758
graph=graph,
27592759
tracker=tracker,
27602760
row_groups=row_groups,
2761-
max_submitted_tasks=8,
2761+
max_in_flight_tasks=8,
27622762
max_concurrent_row_groups=2,
27632763
trace=True,
27642764
)
@@ -2806,7 +2806,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None:
28062806
graph=graph,
28072807
tracker=tracker,
28082808
row_groups=row_groups,
2809-
max_submitted_tasks=4,
2809+
max_in_flight_tasks=4,
28102810
max_model_task_admission=4,
28112811
trace=True,
28122812
)
@@ -2877,7 +2877,7 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None:
28772877
tracker=tracker,
28782878
row_groups=row_groups,
28792879
buffer_manager=buffer_manager,
2880-
max_submitted_tasks=4,
2880+
max_in_flight_tasks=4,
28812881
trace=True,
28822882
)
28832883
await asyncio.wait_for(scheduler.run(), timeout=10.0)
@@ -2925,7 +2925,7 @@ async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None:
29252925
tracker=tracker,
29262926
row_groups=row_groups,
29272927
max_concurrent_row_groups=2,
2928-
max_submitted_tasks=2,
2928+
max_in_flight_tasks=2,
29292929
trace=True,
29302930
num_records=12,
29312931
buffer_size=3,
@@ -3023,7 +3023,7 @@ async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None:
30233023
tracker=tracker,
30243024
row_groups=row_groups,
30253025
max_concurrent_row_groups=1,
3026-
max_submitted_tasks=2,
3026+
max_in_flight_tasks=2,
30273027
max_model_task_admission=1,
30283028
scheduler_event_sink=sink,
30293029
num_records=2,
@@ -3089,7 +3089,7 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon
30893089
tracker=tracker,
30903090
row_groups=row_groups,
30913091
max_concurrent_row_groups=4,
3092-
max_submitted_tasks=4,
3092+
max_in_flight_tasks=4,
30933093
max_model_task_admission=4,
30943094
adaptive_row_group_admission=True,
30953095
adaptive_row_group_initial_target=1,
@@ -3189,6 +3189,17 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N
31893189
assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated"
31903190

31913191

3192+
def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None:
3193+
scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
3194+
scheduler._max_in_flight_tasks = 2
3195+
scheduler._max_model_task_admission = 100
3196+
scheduler._fair_queue = SimpleNamespace(
3197+
view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={})
3198+
)
3199+
3200+
assert scheduler._adaptive_row_group_block_reason() == "queued_task_guardrail"
3201+
3202+
31923203
@pytest.mark.asyncio(loop_scope="session")
31933204
async def test_scheduler_raises_when_ready_frontier_blocked_without_in_flight() -> None:
31943205
provider = _mock_provider()

0 commit comments

Comments
 (0)