diff --git a/architecture/dataset-builders.md b/architecture/dataset-builders.md index c20eeedbc..825a2a392 100644 --- a/architecture/dataset-builders.md +++ b/architecture/dataset-builders.md @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`): 4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler` 5. Hooks `ProcessorRunner` for pre-batch and post-batch stages -`AsyncTaskScheduler` runs on a dedicated async loop with semaphore-based concurrency, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. +`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, semaphore-based capacity limits, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks are admitted through a virtual-time fair queue so one hot column or model-backed generator cannot consume the whole submission window before peer work gets a turn. ### Execution Graph @@ -123,7 +123,7 @@ DatasetBuilder.build() → CompletionTracker.with_graph() → AsyncTaskScheduler(semaphores, salvage_rounds) → scheduler.run() - → for each row group, dispatch ready tasks from frontier + → for each row group, fairly admit ready tasks from frontier → tasks execute generators, update CompletionTracker → checkpoints via RowGroupBufferManager → collect TaskTraces, emit telemetry @@ -133,6 +133,7 @@ DatasetBuilder.build() - **Dual execution engines behind one API.** The sequential engine is simpler and easier to debug; the async engine adds row-group parallelism for throughput. Users switch via an environment variable without changing their code. - **DAG-driven ordering** ensures columns with dependencies (e.g., a judge column that depends on a text column) are generated in the correct order, regardless of the order they appear in the config. +- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. Global semaphores still bound memory/coroutine growth, while per-group virtual-time queues prevent a large ready frontier from degenerating into a column-by-column wave. LLM admission caps are peer-sensitive: a solo model group can fill available global capacity, but once another scheduling group has queued work the saturated group yields until peers get admission slots or admitted tasks complete. - **Salvage rounds in async mode** retry failed tasks after all other tasks in a round complete, improving resilience against transient LLM failures without blocking the entire generation. - **Unified DAG construction.** `topologically_sort_column_configs` (in `execution_graph.py`) determines column ordering using Kahn's algorithm; the runtime `ExecutionGraph` adds strategy-aware dependency tracking for the async scheduler. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 688ec529b..778501da1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -7,7 +7,7 @@ import contextlib import logging import time -from collections import deque +from collections import defaultdict, deque from collections.abc import Coroutine from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable @@ -21,8 +21,14 @@ DEFAULT_REPORT_INTERVAL, AsyncProgressReporter, ) -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.utils.fair_task_queue import ( + FairTaskQueue, + TaskGroupKey, + TaskGroupSpec, +) from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker +from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record from data_designer.engine.dataset_builders.utils.skip_tracker import ( apply_skip_to_record, @@ -40,7 +46,10 @@ logger = logging.getLogger(__name__) DEFAULT_TASK_POOL_SIZE: int = 256 -LLM_WAIT_POOL_MULTIPLIER: int = 2 +# Global LLM wait-pool headroom sizes the memory-safety semaphore above provider capacity. +GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER: int = 2 +# Per-group admission backlog caps how many ready LLM tasks one fair-queue group can hold. +LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 # Degraded-provider WARN: emit at most one warning per interval when the # rolling fraction of retryable errors exceeds the threshold. Distinct from @@ -76,6 +85,15 @@ class _RowGroupState: in_flight_count: int = 0 +@dataclass(frozen=True) +class _DispatchOutcome: + """Result of one fair-dispatch pass over the persistent ready queue.""" + + dispatched: bool = False + submission_full: bool = False + group_blocked: bool = False + + class AsyncTaskScheduler: """Dependency-aware async task scheduler for the dataset builder. @@ -96,7 +114,7 @@ def __init__( max_llm_wait_tasks: int = DEFAULT_TASK_POOL_SIZE, salvage_max_rounds: int = 2, on_finalize_row_group: Callable[[int], None] | None = None, - on_seeds_complete: Callable[[int, int], None] | None = None, + on_seeds_complete: Callable[[int, int], FrontierDelta | None] | None = None, on_before_checkpoint: Callable[[int, int], None] | None = None, shutdown_error_rate: float = 0.5, shutdown_error_window: int = 10, @@ -119,8 +137,15 @@ def __init__( self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups) self._submission_semaphore = TrackingSemaphore(max_submitted_tasks) self._llm_wait_semaphore = TrackingSemaphore(max_llm_wait_tasks) + self._max_llm_wait_tasks = max_llm_wait_tasks self._llm_bound_lookup = build_llm_bound_lookup(generators) + self._scheduling_hints = SchedulingHintResolver(generators) + self._fair_queue = FairTaskQueue() + self._pending_pre_batch_ready: defaultdict[int, list[Task]] = defaultdict(list) + self._pending_pre_batch_ready_tasks: set[Task] = set() + # Task group specs are derived from per-generator scheduling hints and flow identity. + self._task_group_spec_cache: dict[int, TaskGroupSpec] = {} self._dispatched: set[Task] = set() self._in_flight: set[Task] = set() @@ -204,7 +229,7 @@ def __init__( self._rg_size_map: dict[int, int] = dict(row_groups) # Pre-compute seed columns (graph is static) - self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c)) + self._seed_cols: tuple[str, ...] = tuple(c for c in graph.columns if not graph.get_upstream_columns(c)) # Per-column progress tracking (cell-by-cell only; full-column tasks are instant) self._progress_bar = StickyProgressBar() if progress_bar else None @@ -283,6 +308,107 @@ async def _cancel_workers(self) -> None: await asyncio.gather(*self._worker_tasks, return_exceptions=True) self._worker_tasks.clear() + def _apply_frontier_delta(self, delta: FrontierDelta) -> None: + if delta.empty: + return + for task in delta.removed: + self._discard_ready_task(task) + for task in delta.added: + self._enqueue_ready_task(task) + + def _enqueue_ready_task(self, task: Task) -> None: + if task in self._dispatched or task.row_group not in self._rg_states: + return + if not self._tracker.is_frontier_task(task): + return + state = self._rg_states[task.row_group] + if self._on_seeds_complete is not None and not state.pre_batch_done: + if task not in self._pending_pre_batch_ready_tasks: + self._pending_pre_batch_ready[task.row_group].append(task) + self._pending_pre_batch_ready_tasks.add(task) + return + self._fair_queue.enqueue(task, self._task_group_spec(task)) + + def _discard_ready_task(self, task: Task) -> None: + self._fair_queue.discard(task) + self._pending_pre_batch_ready_tasks.discard(task) + + def _flush_pre_batch_ready(self, row_group: int) -> None: + pending = self._pending_pre_batch_ready.pop(row_group, []) + for task in pending: + if task not in self._pending_pre_batch_ready_tasks: + continue + self._pending_pre_batch_ready_tasks.discard(task) + self._enqueue_ready_task(task) + + def _drop_pending_ready_for_row_group(self, row_group: int) -> None: + pending = self._pending_pre_batch_ready.pop(row_group, []) + for task in pending: + self._pending_pre_batch_ready_tasks.discard(task) + self._fair_queue.discard_where(lambda task: task.row_group == row_group) + + def _dispatch_queued_tasks(self) -> _DispatchOutcome: + dispatched = False + + while self._fair_queue.has_queued_tasks: + if not self._submission_semaphore.try_acquire(): + return _DispatchOutcome(dispatched=dispatched, submission_full=True) + + selection = self._fair_queue.admit_next() + if selection is None: + self._submission_semaphore.release() + return _DispatchOutcome(dispatched=dispatched, group_blocked=True) + + self._dispatch_selected_task(selection.task) + dispatched = True + + return _DispatchOutcome(dispatched=dispatched) + + def _dispatch_selected_task(self, task: Task) -> None: + self._dispatched.add(task) + self._in_flight.add(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 + self._spawn_worker(self._execute_task(task)) + + def _task_group_spec(self, task: Task) -> TaskGroupSpec: + generator = self._generators[task.column] + generator_id = id(generator) + cached = self._task_group_spec_cache.get(generator_id) + if cached is not None: + return cached + + spec = self._task_group_spec_from_hint( + self._scheduling_hints.hint_for(generator), + self._task_flow_identity(task), + ) + self._task_group_spec_cache[generator_id] = spec + return spec + + def _task_group_spec_from_hint(self, hint: SchedulingHint, flow_identity: tuple[str, ...]) -> TaskGroupSpec: + if hint.group_kind == "local": + return TaskGroupSpec(key=TaskGroupKey(kind="local", identity=flow_identity)) + + if hint.group_kind == "custom_model": + identity = (*flow_identity, *hint.identity_suffix) + else: + identity = (*hint.identity_prefix, *flow_identity, *hint.identity_suffix) + + weight = max(1, hint.weight) + return TaskGroupSpec( + key=TaskGroupKey(kind=hint.group_kind, identity=identity), + weight=float(weight), + admitted_limit=self._llm_group_admitted_limit(weight), + ) + + def _task_flow_identity(self, task: Task) -> tuple[str, ...]: + generator = self._generators[task.column] + output_columns = self._gen_instance_to_columns.get(id(generator), [task.column]) + return tuple(output_columns) + + def _llm_group_admitted_limit(self, weight: int) -> int: + return max(1, min(self._max_llm_wait_tasks, LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER * weight)) + async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" for rg_id, rg_size in self._row_groups: @@ -349,7 +475,7 @@ async def run(self) -> None: async def _main_dispatch_loop( self, - seed_cols: frozenset[str], + seed_cols: tuple[str, ...], has_pre_batch: bool, all_columns: list[str], ) -> None: @@ -367,25 +493,7 @@ async def _main_dispatch_loop( if has_pre_batch: self._run_seeds_complete_check(seed_cols) - admitted_ids = set(self._rg_states) - ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) - # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured - if has_pre_batch: - ready = [ - t - for t in ready - if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols - ] - semaphore_full = False - for task in ready: - if not self._submission_semaphore.try_acquire(): - semaphore_full = True - break - self._dispatched.add(task) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_task(task)) + dispatch_outcome = self._dispatch_queued_tasks() self._checkpoint_completed_row_groups(all_columns) @@ -400,16 +508,20 @@ async def _main_dispatch_loop( if all_done: break - if not ready and not self._in_flight: + if not self._fair_queue.has_queued_tasks and not self._in_flight: if self._all_rgs_admitted: break - if not ready or semaphore_full: + if ( + not self._fair_queue.has_queued_tasks + or dispatch_outcome.submission_full + or dispatch_outcome.group_blocked + ): await self._wake_event.wait() async def _salvage_rounds( self, - seed_cols: frozenset[str], + seed_cols: tuple[str, ...], has_pre_batch: bool, all_columns: list[str], ) -> None: @@ -464,34 +576,25 @@ async def _salvage_rounds( self._spawn_worker(self._execute_seed_task(task, gid)) else: self._dispatched.discard(task) + self._enqueue_ready_task(task) # Drain: dispatch frontier tasks and any newly-ready downstream tasks # until nothing remains in-flight or in the frontier. - await self._drain_frontier(seed_cols, has_pre_batch, all_columns) + await self._drain_frontier(seed_cols, has_pre_batch) self._checkpoint_completed_row_groups(all_columns) - async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None: + async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: if has_pre_batch: self._run_seeds_complete_check(seed_cols) - admitted_ids = set(self._rg_states) - ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) - if has_pre_batch: - ready = [ - t - for t in ready - if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols - ] - for task in ready: - if not self._submission_semaphore.try_acquire(): - break - self._dispatched.add(task) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_task(task)) - if not ready and not self._in_flight: + dispatch_outcome = self._dispatch_queued_tasks() + has_queued = self._fair_queue.has_queued_tasks + if not has_queued and not self._in_flight: break + if has_queued and not dispatch_outcome.dispatched and not self._in_flight: + raise RuntimeError( + "Ready frontier is admission-blocked with no in-flight task to release scheduler capacity." + ) if not self._in_flight: continue self._wake_event.clear() @@ -499,7 +602,7 @@ async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, async def _salvage_stalled_row_groups( self, - seed_cols: frozenset[str], + seed_cols: tuple[str, ...], has_pre_batch: bool, all_columns: list[str], ) -> None: @@ -583,6 +686,8 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: if completed: checkpointed = {rg_id for rg_id, _ in completed} self._deferred = [t for t in self._deferred if t.row_group not in checkpointed] + for rg_id in checkpointed: + self._drop_pending_ready_for_row_group(rg_id) def _finalize_after_shutdown(self, all_columns: list[str]) -> None: """Salvage row groups left in flight when early shutdown fired. @@ -622,7 +727,7 @@ def _finalize_after_shutdown(self, all_columns: list[str]) -> None: logger.warning(f"Row group {rg_id}: 0 of {rg_size} rows survived early shutdown - skipping write.") self._checkpoint_completed_row_groups(all_columns) - def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: + def _run_seeds_complete_check(self, seed_cols: tuple[str, ...]) -> None: """Run pre-batch callbacks for row groups whose seeds just completed.""" for rg_id, state in list(self._rg_states.items()): if state.seeds_dispatched and not state.pre_batch_done: @@ -631,7 +736,7 @@ def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: state.pre_batch_done = True if self._on_seeds_complete: try: - self._on_seeds_complete(rg_id, state.size) + delta = self._on_seeds_complete(rg_id, state.size) except DatasetGenerationError: raise except Exception as exc: @@ -645,13 +750,16 @@ def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: for ri in range(state.size): if self._tracker.is_dropped(rg_id, ri): self._record_skipped_tasks_for_row(rg_id, ri) + if delta is not None: + self._apply_frontier_delta(delta) + self._flush_pre_batch_ready(rg_id) def _drop_row(self, row_group: int, row_index: int, *, exclude_columns: set[str] | None = None) -> None: if self._tracker.is_dropped(row_group, row_index): return self._record_skipped_tasks_for_row(row_group, row_index, exclude_columns=exclude_columns) - self._tracker.drop_row(row_group, row_index) + self._apply_frontier_delta(self._tracker.drop_row(row_group, row_index)) if self._buffer_manager: self._buffer_manager.drop_row(row_group, row_index) @@ -742,12 +850,14 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: seen_instances.add(gid) task = Task(column=col, row_group=rg_id, row_index=None, task_type="from_scratch") - # Also mark the "batch" variant as dispatched to prevent get_ready_tasks - # from generating a duplicate for this column + # Also mark the "batch" variant as dispatched to prevent duplicate + # scheduling for this column. batch_alias = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") if task in self._dispatched or batch_alias in self._dispatched: continue + # Seeds bypass fair-queue admission while row groups are being admitted; + # direct dispatch preserves stateful lock ordering across row groups. # Acquire stateful lock *before* submission semaphore to preserve # row-group ordering. Held until generation completes (_execute_seed_task). if gid in self._stateful_locks: @@ -842,9 +952,10 @@ async def _execute_task_inner_impl(self, task: Task) -> None: for col in output_cols: if task.row_index is None: rg_size = self._get_rg_size(task.row_group) - self._tracker.mark_row_range_complete(col, task.row_group, rg_size) + delta = self._tracker.mark_row_range_complete(col, task.row_group, rg_size) else: - self._tracker.mark_cell_complete(col, task.row_group, task.row_index) + delta = self._tracker.mark_cell_complete(col, task.row_group, task.row_index) + self._apply_frontier_delta(delta) self._check_error_rate(success=True) # The degraded-provider WARN is provider-scoped: only feed the @@ -901,6 +1012,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None: trace.completed_at = time.perf_counter() self.traces.append(trace) + self._fair_queue.release(task) self._in_flight.discard(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count = max(0, s.in_flight_count - 1) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 3b15ed96d..b820c95aa 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -85,14 +85,14 @@ from data_designer.engine.dataset_builders.async_scheduler import ( DEFAULT_TASK_POOL_SIZE, - LLM_WAIT_POOL_MULTIPLIER, + GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER, AsyncTaskScheduler, ) from data_designer.engine.dataset_builders.utils.async_concurrency import ( AsyncConcurrentExecutor, ensure_async_engine_loop, ) - from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker + from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -996,13 +996,18 @@ def _prepare_async_run( # Pre-batch processor callback: runs after seed tasks complete for a row group. # If it raises, the scheduler propagates the error as DatasetGenerationError (fail-fast). - def on_seeds_complete(rg_id: int, rg_size: int) -> None: + def on_seeds_complete(rg_id: int, rg_size: int) -> FrontierDelta: df = buffer_manager.get_dataframe(rg_id) df = self._processor_runner.run_pre_batch_on_df(df, strict_row_count=True) buffer_manager.replace_dataframe(rg_id, df) + deltas: list[FrontierDelta] = [] for ri in range(rg_size): if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri): - tracker.drop_row(rg_id, ri) + deltas.append(tracker.drop_row(rg_id, ri)) + return FrontierDelta( + added=tuple(task for delta in deltas for task in delta.added), + removed=tuple(task for delta in deltas for task in delta.removed), + ) # Post-batch processor callback: runs after all columns, before finalization. def on_before_checkpoint(rg_id: int, rg_size: int) -> None: @@ -1022,7 +1027,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: row_groups=row_groups, buffer_manager=buffer_manager, max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, - max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate), + max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER * aggregate), on_finalize_row_group=on_finalize_row_group, on_seeds_complete=( on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 881bf6c54..2d35ec0be 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy @@ -13,6 +14,18 @@ from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +@dataclass(frozen=True) +class FrontierDelta: + """Tasks added to or removed from the ready frontier by a tracker mutation.""" + + added: tuple[Task, ...] = () + removed: tuple[Task, ...] = () + + @property + def empty(self) -> bool: + return not self.added and not self.removed + + class CompletionTracker: """Tracks which cells (column, row_group, row_index) are done. @@ -42,24 +55,34 @@ def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> tracker._row_group_sizes = dict(row_groups) return tracker - def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None: + def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> FrontierDelta: self._validate_row_group(row_group) self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete") self._completed[row_group][column].add(row_index) + removed: list[Task] = [] + added: list[Task] = [] if self._graph is not None: - self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")) - self._enqueue_downstream(column, row_group, row_index=row_index) + task = Task(column=column, row_group=row_group, row_index=row_index, task_type="cell") + if self._discard_frontier_task(task): + removed.append(task) + added.extend(self._enqueue_downstream(column, row_group, row_index=row_index)) + return self._record_delta(added=added, removed=removed) - def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None: + def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> FrontierDelta: expected = self._validate_row_group(row_group) self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete") if expected is not None and row_group_size != expected: raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") self._completed[row_group][column] = set(range(row_group_size)) self._batch_complete[row_group].add(column) + removed: list[Task] = [] + added: list[Task] = [] if self._graph is not None: - self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) - self._enqueue_downstream(column, row_group, row_index=None) + task = Task(column=column, row_group=row_group, row_index=None, task_type="batch") + if self._discard_frontier_task(task): + removed.append(task) + added.extend(self._enqueue_downstream(column, row_group, row_index=None)) + return self._record_delta(added=added, removed=removed) def is_complete(self, ref: SliceRef) -> bool: return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set()) @@ -89,15 +112,20 @@ def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: dropped = self._dropped.get(row_group_index, set()) return all(ri in completed or ri in dropped for ri in range(rg_size)) - def drop_row(self, row_group: int, row_index: int) -> None: + def drop_row(self, row_group: int, row_index: int) -> FrontierDelta: self._validate_row_group(row_group) self._dropped[row_group].add(row_index) + removed: list[Task] = [] + added: list[Task] = [] if self._graph is not None: # Remove cell tasks for this row from the frontier for col in self._graph.columns: - self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")) + task = Task(column=col, row_group=row_group, row_index=row_index, task_type="cell") + if self._discard_frontier_task(task): + removed.append(task) # Dropping a row may unblock batch downstream tasks - self._reevaluate_batch_tasks(row_group) + added.extend(self._reevaluate_batch_tasks(row_group)) + return self._record_delta(added=added, removed=removed) def is_dropped(self, row_group: int, row_index: int) -> bool: return row_index in self._dropped.get(row_group, set()) @@ -129,6 +157,10 @@ def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = t for t in self._frontier if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) ] + def is_frontier_task(self, task: Task) -> bool: + """Return whether *task* is still in the ready frontier.""" + return task in self._frontier + def seed_frontier(self) -> None: """Populate the frontier with root tasks (columns with no upstream deps). @@ -147,10 +179,26 @@ def seed_frontier(self) -> None: else: self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) - def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None: + def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta: + return FrontierDelta(added=tuple(added), removed=tuple(removed)) + + def _add_frontier_task(self, task: Task) -> bool: + if task in self._frontier: + return False + self._frontier.add(task) + return True + + def _discard_frontier_task(self, task: Task) -> bool: + if task not in self._frontier: + return False + self._frontier.remove(task) + return True + + def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> list[Task]: """Add newly-ready downstream tasks to the frontier.""" if self._graph is None: raise RuntimeError("This method requires a graph to be set.") + added: list[Task] = [] rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) @@ -175,7 +223,8 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None and all(row_index in s for s in cell_up_completed) ): task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell") - self._frontier.add(task) + if self._add_frontier_task(task): + added.append(task) else: # Batch completion: check all non-dropped, non-complete rows down_completed = rg_completed.get(down, set()) @@ -184,19 +233,23 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None continue if all(ri in s for s in cell_up_completed): task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell") - self._frontier.add(task) + if self._add_frontier_task(task): + added.append(task) else: # FULL_COLUMN downstream: ready when all cell upstreams are fully complete if down not in rg_batch_complete and self._are_cell_ups_complete( cell_ups, rg_completed, rg_size, rg_dropped ): task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") - self._frontier.add(task) + if self._add_frontier_task(task): + added.append(task) + return added - def _reevaluate_batch_tasks(self, row_group: int) -> None: + def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]: """Check if any batch tasks became ready after a row was dropped.""" if self._graph is None: raise RuntimeError("This method requires a graph to be set.") + added: list[Task] = [] rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) @@ -212,7 +265,9 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None: continue if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") - self._frontier.add(task) + if self._add_frontier_task(task): + added.append(task) + return added def _are_cell_ups_complete( self, diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py new file mode 100644 index 000000000..32301b767 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import heapq +from collections import deque +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal + +from data_designer.engine.dataset_builders.utils.task_model import Task + + +@dataclass(frozen=True, order=True) +class TaskGroupKey: + """Stable identity for a stream of related scheduler tasks.""" + + kind: Literal["model", "custom_model", "local"] + identity: tuple[str, ...] + + +@dataclass(frozen=True) +class TaskGroupSpec: + """Scheduling metadata for a task group.""" + + key: TaskGroupKey + weight: float = 1.0 + admitted_limit: int | None = None + + +@dataclass(frozen=True) +class TaskSelection: + """A task selected for dispatch with the group metadata used to choose it.""" + + task: Task + group: TaskGroupSpec + + +class FairTaskQueue: + """Virtual-time fair queue with peer-sensitive per-group FIFO admission limits.""" + + def __init__(self) -> None: + self._queues: dict[TaskGroupKey, deque[Task]] = {} + self._queued: set[Task] = set() + self._task_groups: dict[Task, TaskGroupKey] = {} + self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + self._group_finish: dict[TaskGroupKey, float] = {} + self._admitted_by_group: dict[TaskGroupKey, int] = {} + self._admitted_task_groups: dict[Task, TaskGroupKey] = {} + self._heap: list[tuple[float, int, TaskGroupKey]] = [] + self._active_heap_keys: set[TaskGroupKey] = set() + self._sequence = 0 + self._virtual_time = 0.0 + + @property + def has_queued_tasks(self) -> bool: + return bool(self._queued) + + def enqueue(self, task: Task, group: TaskGroupSpec) -> None: + """Add one ready task to its fair scheduling group.""" + self._group_specs[group.key] = group + if task in self._queued: + return + queue = self._queues.setdefault(group.key, deque()) + queue.append(task) + self._queued.add(task) + self._task_groups[task] = group.key + self._activate_group(group.key) + + def discard(self, task: Task) -> None: + """Remove a queued task lazily if it is no longer dispatchable.""" + self._queued.discard(task) + self._task_groups.pop(task, None) + + def discard_where(self, predicate: Callable[[Task], bool]) -> None: + """Remove queued tasks matching a predicate.""" + for task in tuple(self._queued): + if predicate(task): + self.discard(task) + + def admit_next(self) -> TaskSelection | None: + """Admit the next eligible task, or ``None`` if no queued group can run.""" + blocked: list[TaskGroupKey] = [] + try: + while self._heap: + finish, _, key = heapq.heappop(self._heap) + self._active_heap_keys.discard(key) + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue: + continue + if not self._can_admit_group(key): + blocked.append(key) + continue + + task = queue.popleft() + self._queued.discard(task) + self._task_groups.pop(task, None) + self._admitted_task_groups[task] = key + self._admitted_by_group[key] = self._admitted_by_group.get(key, 0) + 1 + + group = self._group_specs[key] + self._virtual_time = max(self._virtual_time, finish) + self._group_finish[key] = self._virtual_time + (1.0 / max(group.weight, 1.0)) + self._purge_queue_head(key) + if queue: + self._activate_group(key) + return TaskSelection(task=task, group=group) + return None + finally: + for key in blocked: + self._activate_group(key) + + def release(self, task: Task) -> None: + """Release one previously admitted task from its group limit.""" + key = self._admitted_task_groups.pop(task, None) + if key is None: + return + admitted = self._admitted_by_group.get(key, 0) + if admitted <= 1: + self._admitted_by_group.pop(key, None) + else: + self._admitted_by_group[key] = admitted - 1 + self._activate_group(key) + + def _activate_group(self, key: TaskGroupKey) -> None: + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue or key in self._active_heap_keys: + return + self._sequence += 1 + finish = self._group_finish.get(key, self._virtual_time) + heapq.heappush(self._heap, (finish, self._sequence, key)) + self._active_heap_keys.add(key) + + def _purge_queue_head(self, key: TaskGroupKey) -> None: + queue = self._queues.get(key) + if queue is None: + return + while queue: + task = queue[0] + if task in self._queued and self._task_groups.get(task) == key: + break + queue.popleft() + + def _can_admit_group(self, key: TaskGroupKey) -> bool: + group = self._group_specs[key] + if group.admitted_limit is None: + return True + if self._admitted_by_group.get(key, 0) < group.admitted_limit: + return True + return not self._has_queued_peer_group(key) + + def _has_queued_peer_group(self, key: TaskGroupKey) -> bool: + return any(queued_key != key for queued_key in self._task_groups.values()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py new file mode 100644 index 000000000..dea66eeda --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from data_designer.engine.column_generators.generators.base import ColumnGenerator + +logger = logging.getLogger(__name__) + +SchedulingGroupKind = Literal["local", "model", "custom_model"] + + +@dataclass(frozen=True) +class SchedulingHint: + """Resolved task-scheduling metadata independent of graph flow identity.""" + + group_kind: SchedulingGroupKind + identity_prefix: tuple[str, ...] = () + identity_suffix: tuple[str, ...] = () + weight: int = 1 + + +class SchedulingHintResolver: + """Resolve generator/config/model metadata once for a scheduler run.""" + + def __init__(self, generators: dict[str, ColumnGenerator]) -> None: + self._hints_by_generator_id: dict[int, SchedulingHint] = {} + for column, generator in generators.items(): + generator_id = id(generator) + if generator_id not in self._hints_by_generator_id: + self._hints_by_generator_id[generator_id] = self._resolve_hint(column, generator) + + def hint_for(self, generator: ColumnGenerator) -> SchedulingHint: + return self._hints_by_generator_id[id(generator)] + + def _resolve_hint(self, column: str, generator: ColumnGenerator) -> SchedulingHint: + if not generator.is_llm_bound: + return SchedulingHint(group_kind="local") + + aliases = _model_aliases_for_generator(generator) + if not aliases: + return SchedulingHint(group_kind="model", identity_prefix=("unknown",), weight=1) + + model_parts: list[str] = [] + total_parallel = 0 + primary_alias = getattr(generator.config, "model_alias", None) + for alias in aliases: + try: + model_config = _get_model_config_for_alias(generator, alias) + provider_name = _get_model_provider_name_for_alias(generator, alias) + except Exception: + logger.debug( + "Falling back to custom-model scheduling group for column %r after failing to resolve " + "model alias %r from aliases %r.", + column, + alias, + aliases, + exc_info=True, + ) + return SchedulingHint( + group_kind="custom_model", + identity_suffix=tuple(sorted(aliases)), + weight=max(1, total_parallel), + ) + + max_parallel = getattr(model_config.inference_parameters, "max_parallel_requests", 1) + if not isinstance(max_parallel, int): + max_parallel = 1 + model_parts.extend( + ( + provider_name, + str(model_config.model), + str(model_config.generation_type), + alias, + ) + ) + total_parallel += max_parallel + + weight = max(1, total_parallel) + if len(aliases) == 1 and primary_alias == aliases[0]: + return SchedulingHint( + group_kind="model", + identity_prefix=tuple(model_parts[:3]), + weight=weight, + ) + + return SchedulingHint( + group_kind="custom_model", + identity_suffix=tuple(sorted(aliases)), + weight=weight, + ) + + +def _get_model_config_for_alias(generator: ColumnGenerator, alias: str) -> Any: + get_model_config = getattr(generator, "get_model_config", None) + if callable(get_model_config): + return get_model_config(model_alias=alias) + return generator.resource_provider.model_registry.get_model_config(model_alias=alias) + + +def _get_model_provider_name_for_alias(generator: ColumnGenerator, alias: str) -> str: + get_provider_name = getattr(generator, "get_model_provider_name", None) + if callable(get_provider_name): + return str(get_provider_name(model_alias=alias)) + provider = generator.resource_provider.model_registry.get_model_provider(model_alias=alias) + return str(provider.name) + + +def _model_aliases_for_generator(generator: ColumnGenerator) -> list[str]: + get_aliases = getattr(generator.config, "get_model_aliases", None) + if callable(get_aliases): + aliases = get_aliases() + else: + aliases = [] + if (alias := getattr(generator.config, "model_alias", None)) is not None: + aliases.append(alias) + aliases.extend(getattr(generator.config, "model_aliases", []) or []) + return list(dict.fromkeys(alias for alias in aliases if alias)) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index dab109dbb..684c009ba 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -24,7 +24,7 @@ ) from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager from data_designer.engine.resources.resource_provider import ResourceProvider @@ -304,11 +304,16 @@ async def test_dropped_rows_reduce_actual_record_count() -> None: buffer_manager = RowGroupBufferManager(storage) - def drop_all_in_rg1(rg_id: int, rg_size: int) -> None: + def drop_all_in_rg1(rg_id: int, rg_size: int) -> FrontierDelta: + deltas: list[FrontierDelta] = [] if rg_id == 1: for ri in range(rg_size): - tracker.drop_row(rg_id, ri) + deltas.append(tracker.drop_row(rg_id, ri)) buffer_manager.drop_row(rg_id, ri) + return FrontierDelta( + added=tuple(task for delta in deltas for task in delta.added), + removed=tuple(task for delta in deltas for task in delta.removed), + ) scheduler = AsyncTaskScheduler( generators=gen_map, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index fe536957c..6097232ef 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -5,6 +5,7 @@ import asyncio from collections.abc import Callable +from types import SimpleNamespace from typing import Any from unittest.mock import MagicMock @@ -20,6 +21,7 @@ SamplerColumnConfig, ) from data_designer.config.custom_column import custom_column_generator +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.config.sampler_params import SamplerType from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, @@ -29,9 +31,10 @@ from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler, build_llm_bound_lookup from data_designer.engine.dataset_builders.errors import DatasetGenerationError -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager +from data_designer.engine.dataset_builders.utils.task_model import Task from data_designer.engine.models.errors import ( RETRYABLE_MODEL_ERRORS, ModelInternalServerError, @@ -784,7 +787,7 @@ def fail_pre_batch(row_group: int, row_group_size: int) -> None: @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_error_rate_shutdown() -> None: +async def test_scheduler_error_rate_shutdown(caplog: pytest.LogCaptureFixture) -> None: """Early shutdown triggers when error rate exceeds threshold.""" provider = _mock_provider() configs = [ @@ -820,12 +823,13 @@ async def test_scheduler_error_rate_shutdown() -> None: shutdown_error_rate=0.5, shutdown_error_window=2, ) - await scheduler.run() + with caplog.at_level("ERROR", logger="data_designer.engine.dataset_builders.async_scheduler"): + await scheduler.run() # Early shutdown: not all rows should be checkpointed (some row groups incomplete) + assert scheduler.early_shutdown assert buffer_mgr.actual_num_records < 10 - # No leftover unfinished row groups (finalize-after-shutdown drains them). - assert not scheduler._rg_states + assert not any("unfinished row group" in record.getMessage() for record in caplog.records) @pytest.mark.asyncio(loop_scope="session") @@ -1412,6 +1416,29 @@ def generate(self, data: dict) -> dict: return data +class MockConfiguredModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): + """Mock cell generator with model-registry helpers.""" + + @property + def is_llm_bound(self) -> bool: + return True + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data[self.config.name] = f"model_{data.get('seed', '?')}" + return data + + def get_model_config(self, model_alias: str) -> ModelConfig: + return self.resource_provider.model_registry.get_model_config(model_alias=model_alias) + + def get_model_provider_name(self, model_alias: str) -> str: + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return str(provider.name) + + class MockLLMBoundRateLimitGenerator(ColumnGenerator[ExpressionColumnConfig]): """LLM-bound generator that raises ModelRateLimitError for the first N calls, then succeeds.""" @@ -1554,6 +1581,120 @@ async def test_scheduler_deadlock_regression() -> None: assert tracker.is_row_group_complete(0, 2, ["seed", "llm_col"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_drain_frontier_raises_when_ready_but_no_capacity_or_inflight() -> None: + """A broken admission state fails fast instead of spinning in the drain loop. + + This intentionally calls private frontier helpers: the state is an invariant + violation that public ``run()`` should never construct, but the fail-fast + guard prevents infinite waits if future scheduler changes create it. + """ + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + seed_delta = tracker.mark_row_range_complete("seed", 0, 1) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_submitted_tasks=0, + ) + scheduler._rg_states[0] = MagicMock(size=1) + scheduler._apply_frontier_delta(seed_delta) + + with pytest.raises(RuntimeError, match="Ready frontier is admission-blocked"): + await scheduler._drain_frontier(("seed",), False) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_dispatch_does_not_scan_ready_frontier(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 3)]) + + def fail_get_ready_tasks(*args: Any, **kwargs: Any) -> list[Task]: + raise AssertionError("scheduler should apply returned frontier deltas instead of scanning ready tasks") + + monkeypatch.setattr(tracker, "get_ready_tasks", fail_get_ready_tasks) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 3)], + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_pre_batch_drop_removes_pending_ready_task() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 3)]) + + def drop_middle_row(row_group: int, row_group_size: int) -> FrontierDelta: + del row_group_size + return tracker.drop_row(row_group, 1) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 3)], + on_seeds_complete=drop_middle_row, + trace=True, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + cell_traces = [trace for trace in scheduler.traces if trace.column == "cell_out"] + assert {trace.row_index for trace in cell_traces} == {0, 2} + assert tracker.is_dropped(0, 1) + assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_is_llm_bound_property_drives_lookup() -> None: """is_llm_bound property on generators drives the lookup, not isinstance.""" @@ -1595,6 +1736,155 @@ def gen_no_models(row: dict) -> dict: assert lookup == {"custom_llm": True, "custom_plain": False} +def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: + provider = MagicMock() + provider.model_registry = MagicMock() + provider.model_registry.get_model_config.side_effect = lambda model_alias: configs[model_alias] + provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") + return provider + + +def test_scheduler_model_task_group_spec_uses_model_resource_and_flow() -> None: + """Direct spec coverage keeps model identity and flow composition deterministic.""" + model_config = ModelConfig( + alias=MODEL_ALIAS, + model="model-text", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), + provider="mock-provider", + ) + provider = _provider_with_model_configs({MODEL_ALIAS: model_config}) + column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) + generator = MockConfiguredModelCellGenerator(config=column_config, resource_provider=provider) + graph = ExecutionGraph.create([column_config], {"answer": GenerationStrategy.CELL_BY_CELL}) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + scheduler = AsyncTaskScheduler( + generators={"answer": generator}, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + max_llm_wait_tasks=5, + ) + + spec = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) + + assert spec.key.kind == "model" + assert spec.key.identity[:2] == ("mock-provider", "model-text") + assert spec.key.identity[-1] == "answer" + assert spec.weight == 3.0 + assert spec.admitted_limit == 5 + + +def test_scheduler_task_group_spec_is_cached_per_generator() -> None: + """The per-generator spec cache has no stable public signal, so isolate it directly.""" + model_config = ModelConfig( + alias=MODEL_ALIAS, + model="model-text", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), + provider="mock-provider", + ) + provider = _provider_with_model_configs({MODEL_ALIAS: model_config}) + column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) + generator = MockConfiguredModelCellGenerator(config=column_config, resource_provider=provider) + graph = ExecutionGraph.create([column_config], {"answer": GenerationStrategy.CELL_BY_CELL}) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + scheduler = AsyncTaskScheduler( + generators={"answer": generator}, + graph=graph, + tracker=tracker, + row_groups=[(0, 2)], + max_llm_wait_tasks=5, + ) + + spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) + spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) + + assert spec_a is spec_b + assert provider.model_registry.get_model_config.call_count == 1 + assert provider.model_registry.get_model_provider.call_count == 1 + + +def test_scheduler_task_group_spec_logs_debug_on_model_resolution_fallback( + caplog: pytest.LogCaptureFixture, +) -> None: + """Direct spec resolution isolates fallback logging without timing-based scheduler traces.""" + provider = MagicMock() + provider.model_registry = MagicMock() + provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") + provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") + column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) + generator = MockConfiguredModelCellGenerator(config=column_config, resource_provider=provider) + graph = ExecutionGraph.create([column_config], {"answer": GenerationStrategy.CELL_BY_CELL}) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + + with caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): + scheduler = AsyncTaskScheduler( + generators={"answer": generator}, + graph=graph, + tracker=tracker, + row_groups=[(0, 2)], + max_llm_wait_tasks=5, + ) + spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) + spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) + + assert spec_a is spec_b + assert spec_a.key.kind == "custom_model" + assert spec_a.key.identity == ("answer", MODEL_ALIAS) + assert spec_a.weight == 1.0 + assert provider.model_registry.get_model_config.call_count == 1 + fallback_records = [ + record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() + ] + assert len(fallback_records) == 1 + assert "answer" in fallback_records[0].getMessage() + assert MODEL_ALIAS in fallback_records[0].getMessage() + assert fallback_records[0].exc_info is not None + + +def test_scheduler_custom_model_task_group_spec_uses_alias_set_weight() -> None: + """Direct spec coverage verifies custom-model alias aggregation before fair admission.""" + + @custom_column_generator(model_aliases=["draft", "judge"]) + def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: + row["custom_llm"] = "val" + return row + + provider = _provider_with_model_configs( + { + "draft": ModelConfig( + alias="draft", + model="model-draft", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=2), + provider="mock-provider", + ), + "judge": ModelConfig( + alias="judge", + model="model-judge", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), + provider="mock-provider", + ), + } + ) + config = CustomColumnConfig(name="custom_llm", generator_function=gen_with_models) + generator = CustomColumnGenerator(config=config, resource_provider=provider) + graph = ExecutionGraph.create([config], {"custom_llm": GenerationStrategy.CELL_BY_CELL}) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + scheduler = AsyncTaskScheduler( + generators={"custom_llm": generator}, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + max_llm_wait_tasks=10, + ) + + spec = scheduler._task_group_spec(Task(column="custom_llm", row_group=0, row_index=0, task_type="cell")) + + assert spec.key.kind == "custom_model" + assert spec.key.identity == ("custom_llm", "draft", "judge") + assert spec.weight == 5.0 + assert spec.admitted_limit == 10 + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_llm_bound_429_retried_in_salvage() -> None: """A 429'd LLM-bound task is deferred, retried in salvage (handoff runs twice), and completes.""" @@ -1856,6 +2146,181 @@ async def agenerate(self, data: dict) -> dict: return self.generate(data) +class SlowLLMBoundCellGenerator(SlowCellGenerator): + """Slow cell generator that participates in LLM-wait scheduling.""" + + @property + def is_llm_bound(self) -> bool: + return True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_fair_admission_across_ready_columns() -> None: + """A large ready frontier is admitted across columns instead of one column at a time.""" + provider = _mock_provider() + gen_names = ["gen_a", "gen_b", "gen_c"] + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({c: GenerationStrategy.CELL_BY_CELL for c in gen_names}) + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + **{ + name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.05) + for name in gen_names + }, + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 12)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_submitted_tasks=4, + trace=True, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + first_window = [ + trace.column + for trace in sorted((t for t in scheduler.traces if t.column in gen_names), key=lambda t: t.dispatched_at)[:4] + ] + + assert set(first_window[:3]) == set(gen_names) + assert max(first_window.count(column) for column in gen_names) <= 2 + assert tracker.is_row_group_complete(0, 12, ["topic", *gen_names]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_fair_admission_across_ready_columns_and_row_groups() -> None: + """Fair admission stays column-balanced when multiple row groups are ready.""" + provider = _mock_provider() + gen_names = ["gen_a", "gen_b", "gen_c"] + + class BarrierSeedGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): + def __init__(self, *args: Any, expected_calls: int, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._expected_calls = expected_calls + self._started = 0 + self._lock = asyncio.Lock() + self._release = asyncio.Event() + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({self.config.name: ["A"] * num_records}) + + async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + async with self._lock: + self._started += 1 + if self._started == self._expected_calls: + self._release.set() + await self._release.wait() + return self.generate_from_scratch(num_records) + + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({c: GenerationStrategy.CELL_BY_CELL for c in gen_names}) + row_groups = [(0, 3), (1, 3)] + generators: dict[str, ColumnGenerator] = { + "topic": BarrierSeedGenerator( + config=_expr_config("topic"), + resource_provider=provider, + expected_calls=len(row_groups), + ), + **{ + name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.05) + for name in gen_names + }, + } + + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_submitted_tasks=8, + max_concurrent_row_groups=2, + trace=True, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + cell_traces = sorted( + (t for t in scheduler.traces if t.column in gen_names), + key=lambda t: t.dispatched_at, + ) + first_six = cell_traces[:6] + first_twelve = cell_traces[:12] + + assert len(cell_traces) == 18 + assert all({t.column for t in first_six[i : i + 3]} == set(gen_names) for i in range(0, 6, 3)) + assert all(sum(1 for t in first_twelve if t.column == column) == 4 for column in gen_names) + assert {t.row_group for t in first_twelve} == {0, 1} + assert all(tracker.is_row_group_complete(rg_id, rg_size, ["topic", *gen_names]) for rg_id, rg_size in row_groups) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None: + """One LLM-bound column cannot consume the whole initial LLM admission window.""" + provider = _mock_provider() + gen_names = ["hot", "peer"] + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({c: GenerationStrategy.CELL_BY_CELL for c in gen_names}) + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + **{ + name: SlowLLMBoundCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.05) + for name in gen_names + }, + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 8)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_submitted_tasks=4, + max_llm_wait_tasks=4, + trace=True, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + first_window = [ + trace.column + for trace in sorted((t for t in scheduler.traces if t.column in gen_names), key=lambda t: t.dispatched_at)[:4] + ] + + assert first_window.count("hot") == 2 + assert first_window.count("peer") == 2 + assert tracker.is_row_group_complete(0, 8, ["topic", *gen_names]) + assert scheduler.get_semaphore_permits() == (4, 4) + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_downstream_interleaves_with_upstream() -> None: """Downstream judge tasks begin before all upstream gen tasks complete (issue #504). diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 01a69dd7a..2ec7b4cd3 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -193,7 +193,7 @@ def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + delta = ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) @@ -201,6 +201,8 @@ def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> No assert len(question_tasks) == 3 assert all(t.task_type == "cell" for t in question_tasks) assert {t.row_index for t in question_tasks} == {0, 1, 2} + assert set(delta.added) == set(question_tasks) + assert delta.removed == () def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: @@ -215,13 +217,16 @@ def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) - ready_ctx.tracker.drop_row(0, 1) + removed = Task(column="question", row_group=0, row_index=1, task_type="cell") + delta = ready_ctx.tracker.drop_row(0, 1) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 2 assert {t.row_index for t in question_tasks} == {0, 2} + assert delta.added == () + assert delta.removed == (removed,) def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) -> None: @@ -230,12 +235,13 @@ def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) ready_ctx.tracker.mark_cell_complete("question", 0, 0) ready_ctx.tracker.mark_cell_complete("question", 0, 1) # question[2] never completes -- drop it instead - ready_ctx.tracker.drop_row(0, 2) + delta = ready_ctx.tracker.drop_row(0, 2) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 assert score_tasks[0].task_type == "batch" + assert score_tasks[0] in delta.added def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: @@ -252,14 +258,17 @@ def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFi def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + delta = None for ri in range(3): - ready_ctx.tracker.mark_cell_complete("question", 0, ri) + delta = ready_ctx.tracker.mark_cell_complete("question", 0, ri) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 assert score_tasks[0].task_type == "batch" + assert delta is not None + assert delta.added == (score_tasks[0],) def test_get_ready_tasks_multiple_row_groups() -> None: @@ -276,6 +285,14 @@ def test_get_ready_tasks_multiple_row_groups() -> None: assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 +def test_frontier_delta_return_is_empty_when_frontier_does_not_change(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + + delta = ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + + assert delta.empty + + def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py new file mode 100644 index 000000000..b929bce4f --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py @@ -0,0 +1,219 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter + +from data_designer.engine.dataset_builders.utils.fair_task_queue import ( + FairTaskQueue, + TaskGroupKey, + TaskGroupSpec, +) +from data_designer.engine.dataset_builders.utils.task_model import Task + + +def _task(column: str, row_index: int) -> Task: + return Task(column=column, row_group=0, row_index=row_index, task_type="cell") + + +def _group(name: str, *, weight: float = 1.0, admitted_limit: int | None = None) -> TaskGroupSpec: + return TaskGroupSpec( + key=TaskGroupKey(kind="local", identity=(name,)), + weight=weight, + admitted_limit=admitted_limit, + ) + + +def _enqueue(queue: FairTaskQueue, items: list[tuple[Task, TaskGroupSpec]]) -> None: + for task, group in items: + queue.enqueue(task, group) + + +def test_fair_task_queue_equal_groups_round_robins() -> None: + queue = FairTaskQueue() + _enqueue( + queue, + [ + (task, _group(task.column)) + for task in [ + _task("a", 0), + _task("a", 1), + _task("b", 0), + _task("b", 1), + _task("c", 0), + _task("c", 1), + ] + ], + ) + + selected = [queue.admit_next() for _ in range(6)] + + assert [selection.task.column for selection in selected if selection is not None] == ["a", "b", "c", "a", "b", "c"] + + +def test_fair_task_queue_weighted_groups() -> None: + queue = FairTaskQueue() + _enqueue( + queue, + [ + (task, _group(task.column, weight=2 if task.column == "a" else 1)) + for task in [_task("a", i) for i in range(6)] + ] + + [(_task("b", i), _group("b", weight=1)) for i in range(6)], + ) + + selected = [queue.admit_next() for _ in range(6)] + counts = Counter(selection.task.column for selection in selected if selection is not None) + + assert counts == {"a": 4, "b": 2} + + +def test_fair_task_queue_discards_queued_tasks() -> None: + queue = FairTaskQueue() + stale = _task("a", 0) + fresh = _task("a", 1) + + _enqueue(queue, [(stale, _group("a")), (fresh, _group("a"))]) + queue.discard(stale) + + selected = queue.admit_next() + + assert selected is not None + assert selected.task == fresh + assert queue.admit_next() is None + + +def test_fair_task_queue_admitted_cap_skips_saturated_group_with_waiting_peer() -> None: + queue = FairTaskQueue() + capped = _group("a", admitted_limit=1, weight=1_000) + peer = _group("b") + _enqueue( + queue, + [ + (_task("a", 0), capped), + (_task("a", 1), capped), + (_task("b", 0), peer), + (_task("b", 1), peer), + ], + ) + + first = queue.admit_next() + peer_first = queue.admit_next() + selected = queue.admit_next() + + assert first is not None + assert first.task.column == "a" + assert peer_first is not None + assert peer_first.task.column == "b" + assert selected is not None + assert selected.task.column == "b" + + +def test_fair_task_queue_solo_group_can_exceed_admitted_cap() -> None: + queue = FairTaskQueue() + group = _group("a", admitted_limit=1) + first_task = _task("a", 0) + second_task = _task("a", 1) + queue.enqueue(first_task, group) + queue.enqueue(second_task, group) + + first = queue.admit_next() + + assert first is not None + assert first.task == first_task + second = queue.admit_next() + assert second is not None + assert second.task == second_task + assert queue.has_queued_tasks is False + + +def test_fair_task_queue_over_cap_group_yields_to_queued_peer() -> None: + queue = FairTaskQueue() + capped = _group("a", admitted_limit=1) + peer = _group("b") + _enqueue(queue, [(_task("a", i), capped) for i in range(5)]) + + solo_selected = [queue.admit_next() for _ in range(3)] + _enqueue(queue, [(_task("b", i), peer) for i in range(2)]) + peer_selected = [queue.admit_next() for _ in range(2)] + + assert [selection.task.column for selection in solo_selected if selection is not None] == ["a", "a", "a"] + assert [selection.task.column for selection in peer_selected if selection is not None] == ["b", "b"] + + +def test_fair_task_queue_returns_none_when_all_competing_groups_capped() -> None: + queue = FairTaskQueue() + group_a = _group("a", admitted_limit=1) + group_b = _group("b", admitted_limit=1) + _enqueue( + queue, + [ + (_task("a", 0), group_a), + (_task("a", 1), group_a), + (_task("b", 0), group_b), + (_task("b", 1), group_b), + ], + ) + + selected = [queue.admit_next() for _ in range(2)] + + assert [selection.task.column for selection in selected if selection is not None] == ["a", "b"] + assert queue.admit_next() is None + assert queue.has_queued_tasks is True + + +def test_fair_task_queue_release_reopens_saturated_group() -> None: + queue = FairTaskQueue() + group_a = _group("a", admitted_limit=1) + group_b = _group("b", admitted_limit=1) + _enqueue( + queue, + [ + (_task("a", 0), group_a), + (_task("a", 1), group_a), + (_task("b", 0), group_b), + (_task("b", 1), group_b), + ], + ) + first = queue.admit_next() + second = queue.admit_next() + + assert first is not None + assert first.task.column == "a" + assert second is not None + assert second.task.column == "b" + assert queue.admit_next() is None + + queue.release(first.task) + reopened = queue.admit_next() + + assert reopened is not None + assert reopened.task == _task("a", 1) + + +def test_fair_task_queue_no_duplicate_on_repeated_enqueue() -> None: + queue = FairTaskQueue() + task = _task("a", 0) + + queue.enqueue(task, _group("a")) + queue.enqueue(task, _group("a")) + first = queue.admit_next() + + assert first is not None + assert first.task == task + assert queue.admit_next() is None + + +def test_fair_task_queue_discard_where_removes_matching_tasks() -> None: + queue = FairTaskQueue() + _enqueue( + queue, + [(_task(column, i), _group(column)) for column in ["a", "b"] for i in range(2)], + ) + + queue.discard_where(lambda task: task.column == "a") + selected = [queue.admit_next() for _ in range(2)] + + assert [selection.task.column for selection in selected if selection is not None] == ["b", "b"] + assert queue.admit_next() is None diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py new file mode 100644 index 000000000..4e46c07b0 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from data_designer.config.column_configs import ( + CustomColumnConfig, + ExpressionColumnConfig, + GenerationStrategy, + LLMTextColumnConfig, +) +from data_designer.config.custom_column import custom_column_generator +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig +from data_designer.engine.column_generators.generators.base import ColumnGenerator +from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator +from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver +from data_designer.engine.resources.resource_provider import ResourceProvider + +MODEL_ALIAS = "stub" + + +def _expr_config(name: str = "test") -> ExpressionColumnConfig: + return ExpressionColumnConfig(name=name, expr="{{ x }}", dtype="str") + + +def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: + provider = MagicMock(spec=ResourceProvider) + provider.model_registry = MagicMock() + provider.model_registry.get_model_config.side_effect = lambda model_alias: configs[model_alias] + provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") + return provider + + +class LocalCellGenerator(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data[self.config.name] = "local" + return data + + +class ModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): + @property + def is_llm_bound(self) -> bool: + return True + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data[self.config.name] = "model" + return data + + def get_model_config(self, model_alias: str) -> ModelConfig: + return self.resource_provider.model_registry.get_model_config(model_alias=model_alias) + + def get_model_provider_name(self, model_alias: str) -> str: + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return str(provider.name) + + +def test_scheduling_hint_resolver_local_hint_does_not_touch_model_registry() -> None: + provider = MagicMock(spec=ResourceProvider) + provider.model_registry = MagicMock() + generator = LocalCellGenerator(config=_expr_config("local_col"), resource_provider=provider) + + resolver = SchedulingHintResolver({"local_col": generator}) + + assert resolver.hint_for(generator) == SchedulingHint(group_kind="local") + provider.model_registry.get_model_config.assert_not_called() + provider.model_registry.get_model_provider.assert_not_called() + + +def test_scheduling_hint_resolver_resolves_primary_model_once_per_generator() -> None: + model_config = ModelConfig( + alias=MODEL_ALIAS, + model="model-text", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), + provider="mock-provider", + ) + provider = _provider_with_model_configs({MODEL_ALIAS: model_config}) + column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) + generator = ModelCellGenerator(config=column_config, resource_provider=provider) + + resolver = SchedulingHintResolver({"answer": generator, "answer_again": generator}) + hint = resolver.hint_for(generator) + + assert hint.group_kind == "model" + assert hint.identity_prefix[:2] == ("mock-provider", "model-text") + assert hint.weight == 3 + assert provider.model_registry.get_model_config.call_count == 1 + assert provider.model_registry.get_model_provider.call_count == 1 + + +def test_scheduling_hint_resolver_falls_back_to_custom_model_hint_with_debug( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = MagicMock(spec=ResourceProvider) + provider.model_registry = MagicMock() + provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") + provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") + column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) + generator = ModelCellGenerator(config=column_config, resource_provider=provider) + + with caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): + resolver = SchedulingHintResolver({"answer": generator}) + + hint = resolver.hint_for(generator) + + assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=(MODEL_ALIAS,), weight=1) + fallback_records = [ + record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() + ] + assert len(fallback_records) == 1 + assert "answer" in fallback_records[0].getMessage() + assert MODEL_ALIAS in fallback_records[0].getMessage() + assert fallback_records[0].exc_info is not None + + +def test_scheduling_hint_resolver_partial_alias_fallback_preserves_resolved_weight() -> None: + @custom_column_generator(model_aliases=["resolved", "missing"]) + def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: + row["custom_llm"] = "value" + return row + + provider = _provider_with_model_configs( + { + "resolved": ModelConfig( + alias="resolved", + model="model-resolved", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=7), + provider="mock-provider", + ) + } + ) + config = CustomColumnConfig(name="custom_llm", generator_function=gen_with_models) + generator = CustomColumnGenerator(config=config, resource_provider=provider) + + resolver = SchedulingHintResolver({"custom_llm": generator}) + hint = resolver.hint_for(generator) + + assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=("missing", "resolved"), weight=7)