Skip to content
5 changes: 3 additions & 2 deletions architecture/dataset-builders.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`):
4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler`
5. Hooks `ProcessorRunner` for pre-batch and post-batch stages

`AsyncTaskScheduler` runs on a dedicated async loop with semaphore-based concurrency, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially.
`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, semaphore-based capacity limits, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks are admitted through a virtual-time fair queue so one hot column or model-backed generator cannot consume the whole submission window before peer work gets a turn.

### Execution Graph

Expand Down Expand Up @@ -123,7 +123,7 @@ DatasetBuilder.build()
β†’ CompletionTracker.with_graph()
β†’ AsyncTaskScheduler(semaphores, salvage_rounds)
β†’ scheduler.run()
β†’ for each row group, dispatch ready tasks from frontier
β†’ for each row group, fairly admit ready tasks from frontier
β†’ tasks execute generators, update CompletionTracker
β†’ checkpoints via RowGroupBufferManager
β†’ collect TaskTraces, emit telemetry
Expand All @@ -133,6 +133,7 @@ DatasetBuilder.build()

- **Dual execution engines behind one API.** The sequential engine is simpler and easier to debug; the async engine adds row-group parallelism for throughput. Users switch via an environment variable without changing their code.
- **DAG-driven ordering** ensures columns with dependencies (e.g., a judge column that depends on a text column) are generated in the correct order, regardless of the order they appear in the config.
- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. Global semaphores still bound memory/coroutine growth, while per-group virtual-time queues prevent a large ready frontier from degenerating into a column-by-column wave. LLM admission caps are peer-sensitive: a solo model group can fill available global capacity, but once another scheduling group has queued work the saturated group yields until peers get admission slots or admitted tasks complete.
- **Salvage rounds in async mode** retry failed tasks after all other tasks in a round complete, improving resilience against transient LLM failures without blocking the entire generation.
- **Unified DAG construction.** `topologically_sort_column_configs` (in `execution_graph.py`) determines column ordering using Kahn's algorithm; the runtime `ExecutionGraph` adds strategy-aware dependency tracking for the async scheduler.

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@

from data_designer.engine.dataset_builders.async_scheduler import (
DEFAULT_TASK_POOL_SIZE,
LLM_WAIT_POOL_MULTIPLIER,
GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER,
AsyncTaskScheduler,
)
from data_designer.engine.dataset_builders.utils.async_concurrency import (
AsyncConcurrentExecutor,
ensure_async_engine_loop,
)
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta
from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager


Expand Down Expand Up @@ -996,13 +996,18 @@ def _prepare_async_run(

# Pre-batch processor callback: runs after seed tasks complete for a row group.
# If it raises, the scheduler propagates the error as DatasetGenerationError (fail-fast).
def on_seeds_complete(rg_id: int, rg_size: int) -> None:
def on_seeds_complete(rg_id: int, rg_size: int) -> FrontierDelta:
df = buffer_manager.get_dataframe(rg_id)
df = self._processor_runner.run_pre_batch_on_df(df, strict_row_count=True)
buffer_manager.replace_dataframe(rg_id, df)
deltas: list[FrontierDelta] = []
for ri in range(rg_size):
if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri):
tracker.drop_row(rg_id, ri)
deltas.append(tracker.drop_row(rg_id, ri))
return FrontierDelta(
added=tuple(task for delta in deltas for task in delta.added),
removed=tuple(task for delta in deltas for task in delta.removed),
)

# Post-batch processor callback: runs after all columns, before finalization.
def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
Expand All @@ -1022,7 +1027,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
row_groups=row_groups,
buffer_manager=buffer_manager,
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate),
max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER * aggregate),
on_finalize_row_group=on_finalize_row_group,
on_seeds_complete=(
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

from data_designer.config.column_configs import GenerationStrategy
Expand All @@ -13,6 +14,18 @@
from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph


@dataclass(frozen=True)
class FrontierDelta:
"""Tasks added to or removed from the ready frontier by a tracker mutation."""

added: tuple[Task, ...] = ()
removed: tuple[Task, ...] = ()

@property
def empty(self) -> bool:
return not self.added and not self.removed


class CompletionTracker:
"""Tracks which cells (column, row_group, row_index) are done.

Expand Down Expand Up @@ -42,24 +55,34 @@ def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) ->
tracker._row_group_sizes = dict(row_groups)
return tracker

def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None:
def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> FrontierDelta:
self._validate_row_group(row_group)
self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete")
self._completed[row_group][column].add(row_index)
removed: list[Task] = []
added: list[Task] = []
if self._graph is not None:
self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell"))
self._enqueue_downstream(column, row_group, row_index=row_index)
task = Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")
if self._discard_frontier_task(task):
removed.append(task)
added.extend(self._enqueue_downstream(column, row_group, row_index=row_index))
return self._record_delta(added=added, removed=removed)

def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None:
def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> FrontierDelta:
expected = self._validate_row_group(row_group)
self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete")
if expected is not None and row_group_size != expected:
raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}")
self._completed[row_group][column] = set(range(row_group_size))
self._batch_complete[row_group].add(column)
removed: list[Task] = []
added: list[Task] = []
if self._graph is not None:
self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch"))
self._enqueue_downstream(column, row_group, row_index=None)
task = Task(column=column, row_group=row_group, row_index=None, task_type="batch")
if self._discard_frontier_task(task):
removed.append(task)
added.extend(self._enqueue_downstream(column, row_group, row_index=None))
return self._record_delta(added=added, removed=removed)

def is_complete(self, ref: SliceRef) -> bool:
return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set())
Expand Down Expand Up @@ -89,15 +112,20 @@ def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool:
dropped = self._dropped.get(row_group_index, set())
return all(ri in completed or ri in dropped for ri in range(rg_size))

def drop_row(self, row_group: int, row_index: int) -> None:
def drop_row(self, row_group: int, row_index: int) -> FrontierDelta:
self._validate_row_group(row_group)
self._dropped[row_group].add(row_index)
removed: list[Task] = []
added: list[Task] = []
if self._graph is not None:
# Remove cell tasks for this row from the frontier
for col in self._graph.columns:
self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell"))
task = Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")
if self._discard_frontier_task(task):
removed.append(task)
# Dropping a row may unblock batch downstream tasks
self._reevaluate_batch_tasks(row_group)
added.extend(self._reevaluate_batch_tasks(row_group))
return self._record_delta(added=added, removed=removed)

def is_dropped(self, row_group: int, row_index: int) -> bool:
return row_index in self._dropped.get(row_group, set())
Expand Down Expand Up @@ -129,6 +157,10 @@ def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None =
t for t in self._frontier if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs)
]

def is_frontier_task(self, task: Task) -> bool:
"""Return whether *task* is still in the ready frontier."""
return task in self._frontier

def seed_frontier(self) -> None:
"""Populate the frontier with root tasks (columns with no upstream deps).

Expand All @@ -147,10 +179,26 @@ def seed_frontier(self) -> None:
else:
self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch"))

def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None:
def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta:
return FrontierDelta(added=tuple(added), removed=tuple(removed))

def _add_frontier_task(self, task: Task) -> bool:
if task in self._frontier:
return False
self._frontier.add(task)
return True

def _discard_frontier_task(self, task: Task) -> bool:
if task not in self._frontier:
return False
self._frontier.remove(task)
return True

def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> list[Task]:
"""Add newly-ready downstream tasks to the frontier."""
if self._graph is None:
raise RuntimeError("This method requires a graph to be set.")
added: list[Task] = []
rg_completed = self._completed.get(row_group, {})
rg_dropped = self._dropped.get(row_group, set())
rg_batch_complete = self._batch_complete.get(row_group, set())
Expand All @@ -175,7 +223,8 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None
and all(row_index in s for s in cell_up_completed)
):
task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell")
self._frontier.add(task)
if self._add_frontier_task(task):
added.append(task)
else:
# Batch completion: check all non-dropped, non-complete rows
down_completed = rg_completed.get(down, set())
Expand All @@ -184,19 +233,23 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None
continue
if all(ri in s for s in cell_up_completed):
task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell")
self._frontier.add(task)
if self._add_frontier_task(task):
added.append(task)
else:
# FULL_COLUMN downstream: ready when all cell upstreams are fully complete
if down not in rg_batch_complete and self._are_cell_ups_complete(
cell_ups, rg_completed, rg_size, rg_dropped
):
task = Task(column=down, row_group=row_group, row_index=None, task_type="batch")
self._frontier.add(task)
if self._add_frontier_task(task):
added.append(task)
return added

def _reevaluate_batch_tasks(self, row_group: int) -> None:
def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]:
"""Check if any batch tasks became ready after a row was dropped."""
if self._graph is None:
raise RuntimeError("This method requires a graph to be set.")
added: list[Task] = []
rg_completed = self._completed.get(row_group, {})
rg_dropped = self._dropped.get(row_group, set())
rg_batch_complete = self._batch_complete.get(row_group, set())
Expand All @@ -212,7 +265,9 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None:
continue
if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
task = Task(column=col, row_group=row_group, row_index=None, task_type="batch")
self._frontier.add(task)
if self._add_frontier_task(task):
added.append(task)
return added

def _are_cell_ups_complete(
self,
Expand Down
Loading
Loading