Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
SamplerMultiColumnConfig,
SeedDatasetMultiColumnConfig,
)
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError
from data_designer.engine.dataset_builders.utils.execution_graph import topologically_sort_column_configs


def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[DatasetBuilderColumnConfigT]:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@

from __future__ import annotations

import logging
import math
from collections import deque

from data_designer.config.column_configs import GenerationStrategy
from data_designer.config.column_types import ColumnConfigT
from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
from data_designer.engine.dataset_builders.multi_column_configs import (
DatasetBuilderColumnConfigT,
MultiColumnConfig,
)
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.task_model import SliceRef
from data_designer.logging import LOG_INDENT

logger = logging.getLogger(__name__)


class ExecutionGraph:
Expand Down Expand Up @@ -258,3 +264,55 @@ def to_mermaid(self) -> str:
for dep in sorted(self._upstream.get(col, set())):
lines.append(f" {dep} --> {col}")
return "\n".join(lines)


def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
non_dag_cols = [col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)]
dag_col_dict = {col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)}

if not dag_col_dict:
return non_dag_cols

# side_effect_col_name -> producing column name
side_effect_map: dict[str, str] = {}
for name, col in dag_col_dict.items():
for se_col in col.side_effect_columns:
side_effect_map[se_col] = name

def resolve(col_name: str) -> str | None:
if col_name in dag_col_dict:
return col_name
return side_effect_map.get(col_name)

upstream: dict[str, set[str]] = {name: set() for name in dag_col_dict}
downstream: dict[str, set[str]] = {name: set() for name in dag_col_dict}

logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
for name, col in dag_col_dict.items():
for req in col.required_columns:
resolved = resolve(req)
if resolved is None or resolved == name:
continue
logger.debug(f"{LOG_INDENT}🔗 `{name}` depends on `{resolved}`")
upstream[name].add(resolved)
downstream[resolved].add(name)

in_degree = {name: len(ups) for name, ups in upstream.items()}
queue: deque[str] = deque(name for name, deg in in_degree.items() if deg == 0)
order: list[str] = []
while queue:
name = queue.popleft()
order.append(name)
for child in downstream.get(name, set()):
in_degree[child] -= 1
if in_degree[child] == 0:
queue.append(child)

if len(order) != len(dag_col_dict):
raise DAGCircularDependencyError(
"🛑 The Data Designer column configurations contain cyclic dependencies. Please "
"inspect the column configurations and ensure they can be sorted without "
"circular references."
)

return non_dag_cols + [dag_col_dict[n] for n in order]
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.execution_graph import topologically_sort_column_configs

MODEL_ALIAS = "stub-model-alias"

Expand Down Expand Up @@ -78,14 +78,14 @@ def test_dag_construction():

assert sorted_column_configs[0].column_type == DataDesignerColumnType.SAMPLER

assert [c.name for c in sorted_column_configs[1:]] == [
"test_code",
"test_validation",
"depends_on_validation",
"test_judge",
"test_code_and_depends_on_validation_reasoning_traces",
"uses_all_the_stuff",
]
names = [c.name for c in sorted_column_configs[1:]]
assert names[0] == "test_code"
assert names[1] == "test_validation"
assert names[2] == "depends_on_validation"
# test_judge and test_code_and_depends_on_validation_reasoning_traces have no mutual
# dependency, so their relative order is not guaranteed by topological sort.
assert set(names[3:5]) == {"test_judge", "test_code_and_depends_on_validation_reasoning_traces"}
assert names[5] == "uses_all_the_stuff"


def test_circular_dependencies():
Expand Down
Loading