Skip to content

Commit e971e9e

Browse files
committed
Refactor validation a bit and add test.
1 parent 3741f5f commit e971e9e

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

src/modelplane/evaluator/dag.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,14 @@ def _validate_and_build(self) -> None:
9191
if self._validated:
9292
return
9393

94-
# check that all route targets reference registered nodes or instances of the output type
94+
# check that all route targets reference registered nodes or instances
95+
# of the output type, and that all Arbiters have compatible output types
9596
for node_name, node in self._nodes.items():
97+
if isinstance(node, Arbiter):
98+
if not issubclass(node.output_type, self.output_type):
99+
raise ValueError(
100+
f"Node {node_name} is an Arbiter with output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}."
101+
)
96102
for target in node.all_routes():
97103
if target not in self._nodes and not isinstance(
98104
target, self.output_type
@@ -127,16 +133,6 @@ def _validate_and_build(self) -> None:
127133
nodes_in_cycle = set(self._nodes) - set(ordered)
128134
raise ValueError(f"DAG contains a cycle. Nodes in cycle: {nodes_in_cycle}")
129135

130-
# check all terminal Arbiter nodes have correct output types
131-
terminal_nodes = [n for n in self._nodes if not all_routes.get(n)]
132-
for terminal in terminal_nodes:
133-
node = self._nodes[terminal]
134-
if isinstance(node, Arbiter):
135-
if not issubclass(node.output_type, self.output_type):
136-
raise ValueError(
137-
f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}."
138-
)
139-
140136
# build predecessors
141137
for name, node in self._nodes.items():
142138
for target in node.all_routes():

tests/unit/evaluator/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
PromptLengthGate,
2323
ThresholdArbiter,
2424
UnexpectedArbiter,
25+
UnexpectedOutput,
2526
UpperCaser,
2627
UpperCaseScorer,
2728
)
@@ -183,3 +184,18 @@ def bad_dag_with_bad_arbiter():
183184
dag = EvaluatorDAG("test", output_type=Safety)
184185
dag.add_node(BadArbiter(name="bad_arbiter"))
185186
return dag
187+
188+
189+
@pytest.fixture
190+
def bad_one_step_dag():
191+
return (
192+
EvaluatorDAG("one_step", output_type=Safety)
193+
.add_node(
194+
AlwaysFalse(
195+
name="gate",
196+
routes_true=[UnexpectedOutput()],
197+
routes_false=["always_unsafe"],
198+
)
199+
)
200+
.add_node(AlwaysUnsafe(name="always_unsafe"))
201+
)

tests/unit/evaluator/test_dag.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def test_dag_with_bad_arbiter(bad_dag_with_bad_arbiter, sample_ctx):
5555
bad_dag_with_bad_arbiter.run(sample_ctx)
5656

5757

58+
def test_dag_with_bad_output_route(bad_one_step_dag, sample_ctx):
59+
with pytest.raises(
60+
ValueError,
61+
match=r"incompatible output",
62+
):
63+
bad_one_step_dag.run(sample_ctx)
64+
65+
5866
def test_dag_run(simple_dag, sample_ctx):
5967
result = simple_dag.run(sample_ctx)
6068
assert result.name == "UNSAFE"

0 commit comments

Comments
 (0)