-
Notifications
You must be signed in to change notification settings - Fork 170
Expand file tree
/
Copy pathtest_dag.py
More file actions
113 lines (103 loc) · 3.97 KB
/
test_dag.py
File metadata and controls
113 lines (103 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
from data_designer.config.column_configs import (
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMTextColumnConfig,
SamplerColumnConfig,
Score,
ValidationColumnConfig,
)
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.config.sampler_params import SamplerType
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
from data_designer.engine.dataset_builders.utils.execution_graph import topologically_sort_column_configs
# Alias passed as `model_alias` to every LLM column config built by the tests below.
MODEL_ALIAS = "stub-model-alias"
def test_dag_construction():
    """Check that sorted column configs place every column after the columns it references.

    The configs are supplied out of dependency order (``test_validation`` is listed last even
    though ``depends_on_validation`` references it), so the sort has to reorder them.
    """
    column_configs = [
        SamplerMultiColumnConfig(
            columns=[SamplerColumnConfig(name="test_id", sampler_type=SamplerType.UUID, params={})]
        ),
        LLMCodeColumnConfig(
            name="test_code",
            prompt="Write some zig but call it Python.",
            code_lang=CodeLang.PYTHON,
            model_alias=MODEL_ALIAS,
        ),
        LLMCodeColumnConfig(
            name="depends_on_validation",
            prompt="Write {{ test_validation.python_linter_score }}.",
            code_lang=CodeLang.PYTHON,
            model_alias=MODEL_ALIAS,
        ),
        LLMJudgeColumnConfig(
            name="test_judge",
            prompt="Judge this {{ test_code }} {{ depends_on_validation }}",
            scores=[Score(name="test_score", description="test", options={0: "Not Good", 1: "Good"})],
            model_alias=MODEL_ALIAS,
        ),
        ExpressionColumnConfig(
            name="uses_all_the_stuff", expr="{{ test_code }} {{ depends_on_validation }} {{ test_judge }}"
        ),
        ExpressionColumnConfig(
            name="test_code_and_depends_on_validation_reasoning_traces",
            expr="{{ test_code__trace }} {{ depends_on_validation }}",
        ),
        ValidationColumnConfig(
            name="test_validation",
            target_columns=["test_code"],
            validator_type="code",
            validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON),
        ),
    ]

    sorted_column_configs = topologically_sort_column_configs(column_configs)

    assert sorted_column_configs[0].column_type == DataDesignerColumnType.SAMPLER
    names = [c.name for c in sorted_column_configs[1:]]

    # test_code -> test_validation -> depends_on_validation is a strict chain, and each of the
    # remaining columns references depends_on_validation, so these three positions are fixed.
    assert names[:3] == ["test_code", "test_validation", "depends_on_validation"]

    # The remaining three columns are only partially ordered: uses_all_the_stuff references
    # test_judge, but test_code_and_depends_on_validation_reasoning_traces is independent of both.
    # A topological sort may therefore place the traces column before, between, or after the
    # other two, so assert membership plus the single real constraint rather than fixed slots.
    tail = names[3:6]
    assert set(tail) == {
        "test_judge",
        "test_code_and_depends_on_validation_reasoning_traces",
        "uses_all_the_stuff",
    }
    assert tail.index("test_judge") < tail.index("uses_all_the_stuff")
def test_circular_dependencies():
    """Check that two columns whose prompts reference each other make the sort raise."""
    sampler_config = SamplerMultiColumnConfig(
        columns=[SamplerColumnConfig(name="test_id", sampler_type=SamplerType.UUID, params={})]
    )
    # col_1's prompt references col_2 and col_2's prompt references col_1.
    first_config = LLMTextColumnConfig(
        name="col_1",
        prompt="I need you {{ col_2 }}",
        model_alias=MODEL_ALIAS,
    )
    second_config = LLMTextColumnConfig(
        name="col_2",
        prompt="I need you {{ col_1 }}",
        model_alias=MODEL_ALIAS,
    )

    with pytest.raises(DAGCircularDependencyError, match="cyclic dependencies"):
        topologically_sort_column_configs([sampler_config, first_config, second_config])