Skip to content

Commit c911e2c

Browse files
authored
feat: make async in-flight task cap configurable (#699)
1 parent b30f802 commit c911e2c

10 files changed

Lines changed: 91 additions & 47 deletions

File tree

packages/data-designer-config/src/data_designer/config/run_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ class RunConfig(ConfigBase):
133133
buffer_size: Number of records to process in each batch during dataset generation.
134134
A batch is processed end-to-end (column generation, post-batch processors, and writing the batch
135135
to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000.
136+
max_in_flight_tasks: Maximum number of async scheduler tasks that may hold task
137+
leases at once. Tasks may be executing, awaiting I/O, or waiting on model
138+
request admission. Model API request concurrency is controlled separately by
139+
``max_parallel_requests``. Must be >= 1. Default is 1024.
136140
non_inference_max_parallel_workers: Maximum number of worker threads used for non-inference
137141
cell-by-cell generators. Must be >= 1. Default is 4.
138142
max_conversation_restarts: Maximum number of full conversation restarts permitted when
@@ -168,6 +172,14 @@ class RunConfig(ConfigBase):
168172
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
169173
shutdown_error_window: int = Field(default=10, ge=1)
170174
buffer_size: int = Field(default=1000, gt=0)
175+
max_in_flight_tasks: int = Field(
176+
default=1024,
177+
ge=1,
178+
description=(
179+
"Maximum number of async scheduler tasks that may hold task leases at once. "
180+
"Model API request concurrency is controlled separately by max_parallel_requests."
181+
),
182+
)
171183
non_inference_max_parallel_workers: int = Field(default=4, ge=1)
172184
max_conversation_restarts: int = Field(default=5, ge=0)
173185
max_conversation_correction_steps: int = Field(default=0, ge=0)

packages/data-designer-config/tests/config/test_run_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ def test_run_config_accepts_disabled_dropped_column_preservation() -> None:
3333
assert run_config.preserve_dropped_columns is False
3434

3535

36+
def test_run_config_defaults_max_in_flight_tasks_to_1024() -> None:
37+
assert RunConfig().max_in_flight_tasks == 1024
38+
39+
40+
def test_run_config_accepts_custom_max_in_flight_tasks() -> None:
41+
run_config = RunConfig(max_in_flight_tasks=2048)
42+
43+
assert run_config.max_in_flight_tasks == 2048
44+
45+
46+
def test_run_config_rejects_invalid_max_in_flight_tasks() -> None:
47+
with pytest.raises(ValidationError, match="max_in_flight_tasks"):
48+
RunConfig(max_in_flight_tasks=0)
49+
50+
3651
def test_run_config_throttle_shim_rejects_unknown_legacy_fields() -> None:
3752
with pytest.raises(ValidationError, match="max_concurrent_requests"):
3853
RunConfig(throttle={"max_concurrent_requests": 1})

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
stable_task_id,
3939
)
4040
from data_designer.engine.dataset_builders.scheduling.task_admission import (
41+
DEFAULT_IN_FLIGHT_TASK_CAPACITY,
4142
TaskAdmissionConfig,
4243
TaskAdmissionController,
4344
TaskAdmissionDenied,
@@ -77,8 +78,6 @@
7778

7879
logger = logging.getLogger(__name__)
7980

80-
DEFAULT_TASK_POOL_SIZE: int = 256
81-
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2
8281
MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2
8382

8483
# Degraded-provider WARN: emit at most one warning per interval when the
@@ -145,8 +144,8 @@ def __init__(
145144
buffer_manager: RowGroupBufferManager | None = None,
146145
*,
147146
max_concurrent_row_groups: int = 3,
148-
max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
149-
max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE,
147+
max_in_flight_tasks: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
148+
max_model_task_admission: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
150149
task_admission_config: TaskAdmissionConfig | None = None,
151150
salvage_max_rounds: int = 2,
152151
on_finalize_row_group: Callable[[int], None] | None = None,
@@ -184,8 +183,8 @@ def __init__(
184183
model_group_limit_cap=max_model_task_admission,
185184
)
186185
admission_config = task_admission_config or TaskAdmissionConfig(
187-
submission_capacity=max_submitted_tasks,
188-
resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks},
186+
submission_capacity=max_in_flight_tasks,
187+
resource_limits={"llm_wait": max_model_task_admission},
189188
bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(),
190189
)
191190
self._task_admission = TaskAdmissionController(admission_config)
@@ -279,7 +278,7 @@ def __init__(
279278
# Pre-compute row-group sizes for O(1) lookup
280279
self._rg_size_map: dict[int, int] = dict(row_groups)
281280
self._max_concurrent_row_groups = max_concurrent_row_groups
282-
self._max_submitted_tasks = max_submitted_tasks
281+
self._max_in_flight_tasks = max_in_flight_tasks
283282
self._max_model_task_admission = max_model_task_admission
284283
self._num_records = num_records
285284
self._buffer_size = buffer_size
@@ -912,7 +911,7 @@ def _adaptive_row_group_block_reason(self) -> str | None:
912911
if not self._row_group_row_guard_allows(next_size):
913912
return "max_admitted_rows"
914913
queue_view = self._fair_queue.view()
915-
queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2)
914+
queue_guard = self._max_in_flight_tasks * 4
916915
if queue_view.queued_total >= queue_guard:
917916
return "queued_task_guardrail"
918917
task_view = self._task_admission.view()
@@ -1914,7 +1913,7 @@ def capacity_plan(self) -> AsyncCapacityPlan:
19141913
max_admitted_rows=self._adaptive_max_admitted_rows,
19151914
blocked_reasons=dict(self._row_group_admission_blocked_reasons),
19161915
),
1917-
submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"),
1916+
submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"),
19181917
task_resource_limits=CapacityValue(
19191918
value=dict(self._task_admission_config.resource_limits),
19201919
source="engine_internal_config",

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@
8484
import asyncio
8585

8686
from data_designer.engine.dataset_builders.async_scheduler import (
87-
DEFAULT_TASK_POOL_SIZE,
88-
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER,
8987
AsyncTaskScheduler,
9088
)
9189
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
@@ -1055,19 +1053,17 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
10551053
df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id, strict_row_count=True)
10561054
buffer_manager.replace_dataframe(rg_id, df)
10571055

1058-
# Coarse upper bound used only for scheduler task-stage model admission.
1059-
# Concrete provider/model request capacity is enforced by request admission
1060-
# at the model-call boundary.
1061-
aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests()
1056+
max_in_flight_tasks = self._resource_provider.run_config.max_in_flight_tasks
1057+
max_model_task_admission = max_in_flight_tasks
10621058

10631059
scheduler = AsyncTaskScheduler(
10641060
generators=gen_map,
10651061
graph=graph,
10661062
tracker=tracker,
10671063
row_groups=row_groups,
10681064
buffer_manager=buffer_manager,
1069-
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
1070-
max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate),
1065+
max_in_flight_tasks=max_in_flight_tasks,
1066+
max_model_task_admission=max_model_task_admission,
10711067
on_finalize_row_group=on_finalize_row_group,
10721068
on_seeds_complete=(
10731069
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None

packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,14 @@
3636
"unknown_lease",
3737
]
3838
RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192
39+
DEFAULT_IN_FLIGHT_TASK_CAPACITY = 1024
3940

4041

4142
@dataclass(frozen=True)
4243
class TaskAdmissionConfig:
4344
"""Engine-internal scheduler task-stage admission configuration."""
4445

45-
submission_capacity: int = 256
46+
submission_capacity: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY
4647
resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict)
4748
bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None
4849

packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,13 @@ def __init__(self, **kwargs: object) -> None:
201201
monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler)
202202
request_admission = object()
203203
model_registry = MagicMock()
204-
model_registry.get_aggregate_max_parallel_requests.return_value = 2
204+
model_registry.get_aggregate_max_parallel_requests.side_effect = AssertionError(
205+
"model task admission should follow max_in_flight_tasks directly"
206+
)
205207
model_registry.request_admission = request_admission
206208
provider = SimpleNamespace(
207209
model_registry=model_registry,
208-
run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False),
210+
run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False),
209211
)
210212
processor_runner = MagicMock()
211213
processor_runner.has_processors_for.return_value = False
@@ -222,6 +224,8 @@ def __init__(self, **kwargs: object) -> None:
222224

223225
assert captured_kwargs["request_pressure_provider"] is request_admission
224226
assert captured_kwargs["request_pressure_advisory"] is True
227+
assert captured_kwargs["max_in_flight_tasks"] == 64
228+
assert captured_kwargs["max_model_task_admission"] == 64
225229

226230

227231
# -- Test that existing sync path is unaffected --------------------------------

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -857,8 +857,8 @@ async def test_scheduler_stateful_generator_serializes() -> None:
857857

858858

859859
@pytest.mark.asyncio(loop_scope="session")
860-
async def test_scheduler_bounded_submission() -> None:
861-
"""Submitted task count respects max_submitted_tasks."""
860+
async def test_scheduler_bounded_in_flight_tasks() -> None:
861+
"""In-flight task count respects max_in_flight_tasks."""
862862
provider = _mock_provider()
863863

864864
# Use a pipeline with many cells and low submission limit
@@ -884,7 +884,7 @@ async def test_scheduler_bounded_submission() -> None:
884884
graph=graph,
885885
tracker=tracker,
886886
row_groups=row_groups,
887-
max_submitted_tasks=2,
887+
max_in_flight_tasks=2,
888888
)
889889
await scheduler.run()
890890

@@ -1822,22 +1822,22 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None:
18221822
row_groups = [(0, 3)]
18231823
tracker = CompletionTracker.with_graph(graph, row_groups)
18241824

1825-
max_submitted = 2
1825+
max_in_flight = 2
18261826
max_llm_wait = 2
18271827
scheduler = AsyncTaskScheduler(
18281828
generators=generators,
18291829
graph=graph,
18301830
tracker=tracker,
18311831
row_groups=row_groups,
1832-
max_submitted_tasks=max_submitted,
1832+
max_in_flight_tasks=max_in_flight,
18331833
max_model_task_admission=max_llm_wait,
18341834
)
18351835
await scheduler.run()
18361836

18371837
assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"])
18381838

18391839
snapshot = scheduler.task_admission_snapshot()
1840-
assert snapshot.resources_available["submission"] == max_submitted
1840+
assert snapshot.resources_available["submission"] == max_in_flight
18411841
assert snapshot.resources_available["llm_wait"] == max_llm_wait
18421842

18431843

@@ -1886,7 +1886,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18861886
graph=graph,
18871887
tracker=tracker,
18881888
row_groups=row_groups,
1889-
max_submitted_tasks=2,
1889+
max_in_flight_tasks=2,
18901890
max_model_task_admission=max_llm_wait,
18911891
)
18921892
await scheduler.run()
@@ -1899,7 +1899,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
18991899

19001900
@pytest.mark.asyncio(loop_scope="session")
19011901
async def test_scheduler_deadlock_regression() -> None:
1902-
"""max_submitted_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
1902+
"""max_in_flight_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
19031903
provider = _mock_provider()
19041904
configs = [
19051905
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
@@ -1923,7 +1923,7 @@ async def test_scheduler_deadlock_regression() -> None:
19231923
graph=graph,
19241924
tracker=tracker,
19251925
row_groups=row_groups,
1926-
max_submitted_tasks=1,
1926+
max_in_flight_tasks=1,
19271927
max_model_task_admission=1,
19281928
)
19291929

@@ -2398,23 +2398,23 @@ async def test_scheduler_llm_bound_429_retried_in_salvage() -> None:
23982398
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
23992399
buffer_mgr = RowGroupBufferManager(storage)
24002400

2401-
max_submitted = 4
2401+
max_in_flight = 4
24022402
max_llm_wait = 2
24032403
scheduler = AsyncTaskScheduler(
24042404
generators=generators,
24052405
graph=graph,
24062406
tracker=tracker,
24072407
row_groups=row_groups,
24082408
buffer_manager=buffer_mgr,
2409-
max_submitted_tasks=max_submitted,
2409+
max_in_flight_tasks=max_in_flight,
24102410
max_model_task_admission=max_llm_wait,
24112411
)
24122412
await scheduler.run()
24132413

24142414
assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"])
24152415

24162416
snapshot = scheduler.task_admission_snapshot()
2417-
assert snapshot.resources_available["submission"] == max_submitted
2417+
assert snapshot.resources_available["submission"] == max_in_flight
24182418
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24192419

24202420

@@ -2460,15 +2460,15 @@ async def agenerate(self, data: dict) -> dict:
24602460
row_groups = [(0, 2)]
24612461
tracker = CompletionTracker.with_graph(graph, row_groups)
24622462

2463-
max_submitted = 4
2463+
max_in_flight = 4
24642464
max_llm_wait = 2
24652465
sink = InMemoryAdmissionEventSink()
24662466
scheduler = AsyncTaskScheduler(
24672467
generators=generators,
24682468
graph=graph,
24692469
tracker=tracker,
24702470
row_groups=row_groups,
2471-
max_submitted_tasks=max_submitted,
2471+
max_in_flight_tasks=max_in_flight,
24722472
max_model_task_admission=max_llm_wait,
24732473
scheduler_event_sink=sink,
24742474
)
@@ -2481,7 +2481,7 @@ async def agenerate(self, data: dict) -> dict:
24812481
await run_task
24822482

24832483
snapshot = scheduler.task_admission_snapshot()
2484-
assert snapshot.resources_available["submission"] == max_submitted
2484+
assert snapshot.resources_available["submission"] == max_in_flight
24852485
assert snapshot.resources_available["llm_wait"] == max_llm_wait
24862486
assert "cancelled" in [event.event_kind for event in sink.scheduler_events]
24872487
assert all(event.snapshot is not None for event in sink.scheduler_events)
@@ -2703,7 +2703,7 @@ async def test_scheduler_fair_admission_across_ready_columns() -> None:
27032703
graph=graph,
27042704
tracker=tracker,
27052705
row_groups=row_groups,
2706-
max_submitted_tasks=4,
2706+
max_in_flight_tasks=4,
27072707
trace=True,
27082708
)
27092709

@@ -2777,7 +2777,7 @@ async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
27772777
graph=graph,
27782778
tracker=tracker,
27792779
row_groups=row_groups,
2780-
max_submitted_tasks=8,
2780+
max_in_flight_tasks=8,
27812781
max_concurrent_row_groups=2,
27822782
trace=True,
27832783
)
@@ -2825,7 +2825,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None:
28252825
graph=graph,
28262826
tracker=tracker,
28272827
row_groups=row_groups,
2828-
max_submitted_tasks=4,
2828+
max_in_flight_tasks=4,
28292829
max_model_task_admission=4,
28302830
trace=True,
28312831
)
@@ -2896,7 +2896,7 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None:
28962896
tracker=tracker,
28972897
row_groups=row_groups,
28982898
buffer_manager=buffer_manager,
2899-
max_submitted_tasks=4,
2899+
max_in_flight_tasks=4,
29002900
trace=True,
29012901
)
29022902
await asyncio.wait_for(scheduler.run(), timeout=10.0)
@@ -2944,7 +2944,7 @@ async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None:
29442944
tracker=tracker,
29452945
row_groups=row_groups,
29462946
max_concurrent_row_groups=2,
2947-
max_submitted_tasks=2,
2947+
max_in_flight_tasks=2,
29482948
trace=True,
29492949
num_records=12,
29502950
buffer_size=3,
@@ -3042,7 +3042,7 @@ async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None:
30423042
tracker=tracker,
30433043
row_groups=row_groups,
30443044
max_concurrent_row_groups=1,
3045-
max_submitted_tasks=2,
3045+
max_in_flight_tasks=2,
30463046
max_model_task_admission=1,
30473047
scheduler_event_sink=sink,
30483048
num_records=2,
@@ -3108,7 +3108,7 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon
31083108
tracker=tracker,
31093109
row_groups=row_groups,
31103110
max_concurrent_row_groups=4,
3111-
max_submitted_tasks=4,
3111+
max_in_flight_tasks=4,
31123112
max_model_task_admission=4,
31133113
adaptive_row_group_admission=True,
31143114
adaptive_row_group_initial_target=1,
@@ -3208,6 +3208,17 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N
32083208
assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated"
32093209

32103210

3211+
def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None:
3212+
scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
3213+
scheduler._max_in_flight_tasks = 2
3214+
scheduler._max_model_task_admission = 100
3215+
scheduler._fair_queue = SimpleNamespace(
3216+
view=lambda: SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={})
3217+
)
3218+
3219+
assert scheduler._adaptive_row_group_block_reason() == "queued_task_guardrail"
3220+
3221+
32113222
@pytest.mark.asyncio(loop_scope="session")
32123223
async def test_scheduler_raises_when_ready_frontier_blocked_without_in_flight() -> None:
32133224
provider = _mock_provider()

0 commit comments

Comments
 (0)