-
Notifications
You must be signed in to change notification settings - Fork 170
Expand file tree
/
Copy pathtest_dag.py
More file actions
137 lines (121 loc) · 4.73 KB
/
test_dag.py
File metadata and controls
137 lines (121 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import pytest
from data_designer.config.column_configs import (
CustomColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
LLMTextColumnConfig,
SamplerColumnConfig,
Score,
ValidationColumnConfig,
)
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.config.custom_column import custom_column_generator
from data_designer.config.sampler_params import SamplerType
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import CodeValidatorParams
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
from data_designer.engine.dataset_builders.utils.dag import topologically_sort_column_configs
from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError
# Stub model alias passed as `model_alias` to every LLM-backed column config in this module.
MODEL_ALIAS: str = "stub-model-alias"
def test_dag_construction() -> None:
    """Sort column configs so each column comes after the columns it references.

    The configs are intentionally declared out of dependency order. For example,
    ``test_validation`` is declared last even though ``depends_on_validation``
    references it. The assertions therefore check that the sort reorders columns
    rather than keeping their declaration order.
    """
    column_configs = [
        SamplerMultiColumnConfig(
            columns=[SamplerColumnConfig(name="test_id", sampler_type=SamplerType.UUID, params={})]
        ),
        LLMCodeColumnConfig(
            name="test_code",
            prompt="Write some zig but call it Python.",
            code_lang=CodeLang.PYTHON,
            model_alias=MODEL_ALIAS,
        ),
        # References a field of the validation column, so it must be sorted after test_validation.
        LLMCodeColumnConfig(
            name="depends_on_validation",
            prompt="Write {{ test_validation.python_linter_score }}.",
            code_lang=CodeLang.PYTHON,
            model_alias=MODEL_ALIAS,
        ),
        LLMJudgeColumnConfig(
            name="test_judge",
            prompt="Judge this {{ test_code }} {{ depends_on_validation }}",
            scores=[Score(name="test_score", description="test", options={0: "Not Good", 1: "Good"})],
            model_alias=MODEL_ALIAS,
        ),
        ExpressionColumnConfig(
            name="uses_all_the_stuff", expr="{{ test_code }} {{ depends_on_validation }} {{ test_judge }}"
        ),
        # NOTE(review): `test_code__trace` is presumably resolved back to `test_code` as a dependency.
        ExpressionColumnConfig(
            name="test_code_and_depends_on_validation_reasoning_traces",
            expr="{{ test_code__trace }} {{ depends_on_validation }}",
        ),
        # Declared last on purpose: it targets test_code, but depends_on_validation references it,
        # so the sort has to move it forward.
        ValidationColumnConfig(
            name="test_validation",
            target_columns=["test_code"],
            validator_type="code",
            validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON),
        ),
    ]

    sorted_column_configs = topologically_sort_column_configs(column_configs)

    # The sampler group references no other columns and is expected first.
    assert sorted_column_configs[0].column_type == DataDesignerColumnType.SAMPLER
    assert [c.name for c in sorted_column_configs[1:]] == [
        "test_code",
        "test_validation",
        "depends_on_validation",
        "test_judge",
        "test_code_and_depends_on_validation_reasoning_traces",
        "uses_all_the_stuff",
    ]
def test_circular_dependencies() -> None:
    """Reject column configs whose references form a cycle.

    ``col_1`` references ``col_2`` in its prompt and ``col_2`` references
    ``col_1``, so no valid ordering exists. The sort must raise
    ``DAGCircularDependencyError``.
    """
    column_configs = [
        SamplerMultiColumnConfig(
            columns=[SamplerColumnConfig(name="test_id", sampler_type=SamplerType.UUID, params={})]
        ),
        LLMTextColumnConfig(
            name="col_1",
            prompt="I need you {{ col_2 }}",
            model_alias=MODEL_ALIAS,
        ),
        LLMTextColumnConfig(
            name="col_2",
            prompt="I need you {{ col_1 }}",
            model_alias=MODEL_ALIAS,
        ),
    ]

    with pytest.raises(DAGCircularDependencyError, match="cyclic dependencies"):
        topologically_sort_column_configs(column_configs)
def test_duplicate_side_effect_producers_raises() -> None:
    """Reject configs in which two custom columns claim the same side-effect column.

    Both generators declare ``shared_col`` as a side-effect column. Sorting the
    configs must therefore fail with a ``ConfigCompilationError`` whose message
    mentions the column being "already produced by" another column.
    """

    @custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
    def gen_a(row: dict[str, Any]) -> dict[str, Any]:
        return row

    @custom_column_generator(required_columns=["text"], side_effect_columns=["shared_col"])
    def gen_b(row: dict[str, Any]) -> dict[str, Any]:
        return row

    # Both generators require this upstream column.
    upstream_text = LLMTextColumnConfig(name="text", prompt="hello", model_alias=MODEL_ALIAS)
    first_producer = CustomColumnConfig(name="col_a", generator_function=gen_a)
    second_producer = CustomColumnConfig(name="col_b", generator_function=gen_b)
    configs = [upstream_text, first_producer, second_producer]

    with pytest.raises(ConfigCompilationError, match="already produced by"):
        topologically_sort_column_configs(configs)