Skip to content

Commit 83ee424

Browse files
authored
fix: unblock adaptive scheduler bootstrap (#744)
1 parent b0076cd commit 83ee424

2 files changed

Lines changed: 161 additions & 6 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -961,9 +961,10 @@ def _adaptive_row_group_block_reason(self) -> str | None:
961961
return "no_llm_wait_resource"
962962
llm_available = task_view.resources_available.get("llm_wait", 0)
963963
queued_llm = queue_view.queued_peer_demand_by_resource.get("llm_wait", 0)
964+
llm_leased = task_view.leased_resources.get("llm_wait", 0)
964965
if llm_available <= 0:
965966
return "llm_wait_saturated"
966-
if llm_available <= queued_llm and queue_view.queued_total > 0:
967+
if llm_leased > 0 and llm_available <= queued_llm:
967968
return "queued_llm_demand"
968969
return None
969970

@@ -988,7 +989,12 @@ def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]:
988989
"admitted_rows": admitted_rows,
989990
"max_admitted_rows": self._adaptive_max_admitted_rows,
990991
"queued_total": queue_view.queued_total,
992+
"queued_demand_by_resource": dict(queue_view.queued_peer_demand_by_resource),
991993
"queued_llm_wait_demand": queue_view.queued_peer_demand_by_resource.get("llm_wait", 0),
994+
"in_flight_tasks": len(self._in_flight),
995+
"resource_limits": dict(task_view.resource_limits),
996+
"leased_resources": dict(task_view.leased_resources),
997+
"resources_available": dict(task_view.resources_available),
992998
"llm_wait_limit": task_view.resource_limits.get("llm_wait", 0),
993999
"llm_wait_leased": task_view.leased_resources.get("llm_wait", 0),
9941000
"llm_wait_available": task_view.resources_available.get("llm_wait", 0),

packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py

Lines changed: 154 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -375,6 +375,9 @@ def _build_simple_pipeline(
375375
configs: list[SamplerColumnConfig | LLMTextColumnConfig | ExpressionColumnConfig] | None = None,
376376
strategies: dict[str, GenerationStrategy] | None = None,
377377
scheduler_event_sink: Any | None = None,
378+
max_concurrent_row_groups: int = 3,
379+
adaptive_row_group_admission: bool = False,
380+
adaptive_row_group_initial_target: int = 1,
378381
) -> tuple[AsyncTaskScheduler, CompletionTracker]:
379382
"""Build a simple seed → cell pipeline for testing."""
380383
if configs is None:
@@ -412,8 +415,11 @@ def _build_simple_pipeline(
412415
graph=graph,
413416
tracker=tracker,
414417
row_groups=row_groups,
418+
max_concurrent_row_groups=max_concurrent_row_groups,
415419
trace=trace,
416420
scheduler_event_sink=scheduler_event_sink,
421+
adaptive_row_group_admission=adaptive_row_group_admission,
422+
adaptive_row_group_initial_target=adaptive_row_group_initial_target,
417423
)
418424
return scheduler, tracker
419425

@@ -3787,6 +3793,31 @@ def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> N
37873793
assert scheduler._row_group_row_guard_allows(9_000)
37883794

37893795

3796+
def _stub_row_group_admission_resource_views(
3797+
scheduler: AsyncTaskScheduler,
3798+
*,
3799+
queued_total: int,
3800+
queued_llm: int,
3801+
llm_limit: int,
3802+
llm_available: int,
3803+
llm_leased: int,
3804+
) -> None:
3805+
scheduler._fair_queue = SimpleNamespace(
3806+
view=lambda: SimpleNamespace(
3807+
queued_total=queued_total,
3808+
queued_by_group={},
3809+
queued_peer_demand_by_resource={"llm_wait": queued_llm} if queued_llm else {},
3810+
)
3811+
)
3812+
scheduler._task_admission = SimpleNamespace(
3813+
view=lambda: SimpleNamespace(
3814+
resource_limits={"llm_wait": llm_limit},
3815+
resources_available={"llm_wait": llm_available},
3816+
leased_resources={"llm_wait": llm_leased} if llm_leased else {},
3817+
)
3818+
)
3819+
3820+
37903821
def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> None:
37913822
provider = _mock_provider()
37923823
configs = [
@@ -3816,16 +3847,134 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N
38163847
num_records=2,
38173848
buffer_size=1,
38183849
)
3819-
scheduler._fair_queue = SimpleNamespace(
3820-
view=lambda: SimpleNamespace(queued_total=1, queued_peer_demand_by_resource={})
3821-
)
3822-
scheduler._task_admission = SimpleNamespace(
3823-
view=lambda: SimpleNamespace(resource_limits={"llm_wait": 1}, resources_available={"llm_wait": 0})
3850+
_stub_row_group_admission_resource_views(
3851+
scheduler,
3852+
queued_total=1,
3853+
queued_llm=0,
3854+
llm_limit=1,
3855+
llm_available=0,
3856+
llm_leased=1,
38243857
)
38253858

38263859
assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated"
38273860

38283861

3862+
def test_scheduler_adaptive_row_group_block_reason_allows_zero_llm_lease_bootstrap() -> None:
3863+
scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
3864+
_stub_row_group_admission_resource_views(
3865+
scheduler,
3866+
queued_total=3,
3867+
queued_llm=3,
3868+
llm_limit=3,
3869+
llm_available=3,
3870+
llm_leased=0,
3871+
)
3872+
3873+
assert scheduler._adaptive_row_group_block_reason() is None
3874+
3875+
3876+
def test_scheduler_adaptive_row_group_block_reason_blocks_queued_llm_demand_after_bootstrap() -> None:
3877+
scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
3878+
_stub_row_group_admission_resource_views(
3879+
scheduler,
3880+
queued_total=3,
3881+
queued_llm=3,
3882+
llm_limit=4,
3883+
llm_available=3,
3884+
llm_leased=1,
3885+
)
3886+
3887+
assert scheduler._adaptive_row_group_block_reason() == "queued_llm_demand"
3888+
3889+
3890+
def test_scheduler_row_group_admission_diagnostics_include_resource_state_for_events() -> None:
3891+
class ExplodingRequestPressureProvider:
3892+
def snapshots(self) -> None:
3893+
raise AssertionError("request pressure diagnostics should not be captured")
3894+
3895+
def global_snapshots(self) -> None:
3896+
raise AssertionError("request pressure diagnostics should not be captured")
3897+
3898+
scheduler, _tracker = _build_simple_pipeline(
3899+
num_records=2,
3900+
buffer_size=1,
3901+
)
3902+
scheduler._request_pressure_provider = ExplodingRequestPressureProvider()
3903+
_stub_row_group_admission_resource_views(
3904+
scheduler,
3905+
queued_total=3,
3906+
queued_llm=3,
3907+
llm_limit=4,
3908+
llm_available=3,
3909+
llm_leased=1,
3910+
)
3911+
scheduler._in_flight.add(Task(column="cell_out", row_group=0, row_index=0, task_type="cell"))
3912+
3913+
diagnostics = scheduler._row_group_admission_diagnostics(reason="queued_llm_demand")
3914+
3915+
assert diagnostics["queued_demand_by_resource"] == {"llm_wait": 3}
3916+
assert diagnostics["resource_limits"] == {"llm_wait": 4}
3917+
assert diagnostics["leased_resources"] == {"llm_wait": 1}
3918+
assert diagnostics["resources_available"] == {"llm_wait": 3}
3919+
assert diagnostics["in_flight_tasks"] == 1
3920+
assert "request_pressure" not in diagnostics
3921+
3922+
3923+
def test_scheduler_adaptive_row_group_target_grows_for_zero_llm_lease_bootstrap() -> None:
3924+
scheduler, _tracker = _build_simple_pipeline(
3925+
num_records=2,
3926+
buffer_size=1,
3927+
scheduler_event_sink=(sink := InMemoryAdmissionEventSink()),
3928+
max_concurrent_row_groups=2,
3929+
adaptive_row_group_admission=True,
3930+
adaptive_row_group_initial_target=1,
3931+
)
3932+
scheduler._rg_states[0] = SimpleNamespace(size=1)
3933+
_stub_row_group_admission_resource_views(
3934+
scheduler,
3935+
queued_total=3,
3936+
queued_llm=3,
3937+
llm_limit=3,
3938+
llm_available=3,
3939+
llm_leased=0,
3940+
)
3941+
3942+
scheduler._maybe_update_adaptive_row_group_target()
3943+
assert scheduler._row_group_admission_target == 1
3944+
3945+
scheduler._maybe_update_adaptive_row_group_target()
3946+
3947+
assert scheduler._row_group_admission_target == 2
3948+
assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events)
3949+
3950+
3951+
def test_scheduler_adaptive_row_group_target_stays_blocked_after_llm_lease_bootstrap() -> None:
3952+
scheduler, _tracker = _build_simple_pipeline(
3953+
num_records=2,
3954+
buffer_size=1,
3955+
scheduler_event_sink=(sink := InMemoryAdmissionEventSink()),
3956+
max_concurrent_row_groups=2,
3957+
adaptive_row_group_admission=True,
3958+
adaptive_row_group_initial_target=1,
3959+
)
3960+
scheduler._rg_states[0] = SimpleNamespace(size=1)
3961+
_stub_row_group_admission_resource_views(
3962+
scheduler,
3963+
queued_total=3,
3964+
queued_llm=3,
3965+
llm_limit=4,
3966+
llm_available=3,
3967+
llm_leased=1,
3968+
)
3969+
3970+
scheduler._maybe_update_adaptive_row_group_target()
3971+
scheduler._maybe_update_adaptive_row_group_target()
3972+
3973+
assert scheduler._row_group_admission_target == 1
3974+
assert not any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events)
3975+
assert scheduler._row_group_admission_blocked_reasons["queued_llm_demand"] == 2
3976+
3977+
38293978
def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None:
38303979
scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
38313980
scheduler._max_in_flight_tasks = 2

0 commit comments

Comments
 (0)