diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py
index 8112d87e8..208fa6d80 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/config_compiler.py
@@ -11,8 +11,8 @@
     SamplerMultiColumnConfig,
     SeedDatasetMultiColumnConfig,
 )
-from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
 from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError
+from data_designer.engine.dataset_builders.utils.execution_graph import topologically_sort_column_configs
 
 
 def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[DatasetBuilderColumnConfigT]:
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dag.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dag.py
deleted file mode 100644
index fd019137f..000000000
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/dag.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import logging
-
-import data_designer.lazy_heavy_imports as lazy
-from data_designer.config.column_types import ColumnConfigT
-from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
-from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
-from data_designer.logging import LOG_INDENT
-
-logger = logging.getLogger(__name__)
-
-
-def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
-    dag = lazy.nx.DiGraph()
-
-    non_dag_column_config_list = [
-        col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)
-    ]
-    dag_column_config_dict = {
-        col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)
-    }
-
-    if len(dag_column_config_dict) == 0:
-        return non_dag_column_config_list
-
-    side_effect_dict = {n: list(c.side_effect_columns) for n, c in dag_column_config_dict.items()}
-
-    logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
-    for name, col in dag_column_config_dict.items():
-        dag.add_node(name)
-        for req_col_name in col.required_columns:
-            if req_col_name in list(dag_column_config_dict.keys()):
-                logger.debug(f"{LOG_INDENT}🔗 `{name}` depends on `{req_col_name}`")
-                dag.add_edge(req_col_name, name)
-
-            # If the required column is a side effect of another column,
-            # add an edge from the parent column to the current column.
-            elif req_col_name in sum(side_effect_dict.values(), []):
-                for parent, cols in side_effect_dict.items():
-                    if req_col_name in cols:
-                        logger.debug(f"{LOG_INDENT}🔗 `{name}` depends on `{parent}` via `{req_col_name}`")
-                        dag.add_edge(parent, name)
-                        break
-
-    if not lazy.nx.is_directed_acyclic_graph(dag):
-        raise DAGCircularDependencyError(
-            "🛑 The Data Designer column configurations contain cyclic dependencies. Please "
-            "inspect the column configurations and ensure they can be sorted without "
-            "circular references."
-        )
-
-    sorted_columns = non_dag_column_config_list
-    sorted_columns.extend([dag_column_config_dict[n] for n in list(lazy.nx.topological_sort(dag))])
-
-    return sorted_columns
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py
index 29db09c83..f73790a79 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py
@@ -3,16 +3,22 @@
 
 from __future__ import annotations
 
+import logging
 import math
 from collections import deque
 
 from data_designer.config.column_configs import GenerationStrategy
+from data_designer.config.column_types import ColumnConfigT
+from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag
 from data_designer.engine.dataset_builders.multi_column_configs import (
     DatasetBuilderColumnConfigT,
     MultiColumnConfig,
 )
 from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
 from data_designer.engine.dataset_builders.utils.task_model import SliceRef
+from data_designer.logging import LOG_INDENT
+
+logger = logging.getLogger(__name__)
 
 
 class ExecutionGraph:
@@ -258,3 +264,75 @@ def to_mermaid(self) -> str:
             for dep in sorted(self._upstream.get(col, set())):
                 lines.append(f" {dep} --> {col}")
         return "\n".join(lines)
+
+
+def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
+    """Order column configs so that each column comes after the columns it depends on.
+
+    Columns whose type is not used in the execution DAG keep their input order and are
+    placed first. The remaining columns are ordered with Kahn's algorithm, where a
+    required column that is a side effect of another column resolves to the column
+    producing it. The resulting order is deterministic for a given input order.
+
+    Args:
+        column_configs: Column configs to sort.
+
+    Returns:
+        The non-DAG configs followed by the DAG configs in dependency order.
+
+    Raises:
+        DAGCircularDependencyError: If the dependencies between DAG columns form a cycle.
+    """
+    non_dag_cols = [col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)]
+    dag_col_dict = {col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)}
+
+    if not dag_col_dict:
+        return non_dag_cols
+
+    # side_effect_col_name -> producing column name
+    side_effect_map: dict[str, str] = {}
+    for name, col in dag_col_dict.items():
+        for se_col in col.side_effect_columns:
+            side_effect_map[se_col] = name
+
+    def resolve(col_name: str) -> str | None:
+        if col_name in dag_col_dict:
+            return col_name
+        return side_effect_map.get(col_name)
+
+    upstream: dict[str, set[str]] = {name: set() for name in dag_col_dict}
+    # Dict keys act as an insertion-ordered set. Iterating a set[str] here would make
+    # the final order depend on per-process string hash randomization.
+    downstream: dict[str, dict[str, None]] = {name: {} for name in dag_col_dict}
+
+    logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph")
+    for name, col in dag_col_dict.items():
+        for req in col.required_columns:
+            resolved = resolve(req)
+            # Requirements that resolve to no DAG column, or to the column itself, add no edge.
+            if resolved is None or resolved == name:
+                continue
+            logger.debug(f"{LOG_INDENT}🔗 `{name}` depends on `{resolved}`")
+            upstream[name].add(resolved)
+            downstream[resolved][name] = None
+
+    in_degree = {name: len(ups) for name, ups in upstream.items()}
+    queue: deque[str] = deque(name for name, deg in in_degree.items() if deg == 0)
+    order: list[str] = []
+    while queue:
+        name = queue.popleft()
+        order.append(name)
+        for child in downstream[name]:
+            in_degree[child] -= 1
+            if in_degree[child] == 0:
+                queue.append(child)
+
+    # Columns on a cycle never reach in-degree 0, so they are missing from `order`.
+    if len(order) != len(dag_col_dict):
+        raise DAGCircularDependencyError(
+            "🛑 The Data Designer column configurations contain cyclic dependencies. Please "
+            "inspect the column configurations and ensure they can be sorted without "
+            "circular references."
+        )
+
+    return non_dag_cols + [dag_col_dict[n] for n in order]
diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dag.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dag.py
index 8328a8f9d..95812ef43 100644
--- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dag.py
+++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_dag.py
@@ -17,8 +17,8 @@
 from data_designer.config.utils.code_lang import CodeLang
 from data_designer.config.validator_params import CodeValidatorParams
 from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
-from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
 from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
+from data_designer.engine.dataset_builders.utils.execution_graph import topologically_sort_column_configs
 
 MODEL_ALIAS = "stub-model-alias"
 
@@ -78,14 +78,14 @@ def test_dag_construction():
 
     assert sorted_column_configs[0].column_type == DataDesignerColumnType.SAMPLER
 
-    assert [c.name for c in sorted_column_configs[1:]] == [
-        "test_code",
-        "test_validation",
-        "depends_on_validation",
-        "test_judge",
-        "test_code_and_depends_on_validation_reasoning_traces",
-        "uses_all_the_stuff",
-    ]
+    names = [c.name for c in sorted_column_configs[1:]]
+    assert names[0] == "test_code"
+    assert names[1] == "test_validation"
+    assert names[2] == "depends_on_validation"
+    # test_judge and test_code_and_depends_on_validation_reasoning_traces have no mutual
+    # dependency, so their relative order is not guaranteed by topological sort.
+    assert set(names[3:5]) == {"test_judge", "test_code_and_depends_on_validation_reasoning_traces"}
+    assert names[5] == "uses_all_the_stuff"
 
 
 def test_circular_dependencies():