From b66aca14c64639a36f5043ae6428ad9e092a2d39 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Thu, 2 Apr 2026 16:19:03 -0400
Subject: [PATCH 01/19] Initial DAG evaluator commit.

---
 src/modelplane/evaluator/context.py |  17 ++
 src/modelplane/evaluator/dag.py     | 287 ++++++++++++++++++++++++++++
 src/modelplane/evaluator/nodes.py   | 164 ++++++++++++++++
 src/modelplane/evaluator/outputs.py |  30 +++
 tests/unit/evaluator/conftest.py    | 105 ++++++++++
 tests/unit/evaluator/mocks.py       | 109 +++++++++++
 tests/unit/evaluator/test_dag.py    |   9 +
 tests/unit/evaluator/test_nodes.py  |  65 +++++++
 8 files changed, 786 insertions(+)
 create mode 100644 src/modelplane/evaluator/context.py
 create mode 100644 src/modelplane/evaluator/dag.py
 create mode 100644 src/modelplane/evaluator/nodes.py
 create mode 100644 src/modelplane/evaluator/outputs.py
 create mode 100644 tests/unit/evaluator/conftest.py
 create mode 100644 tests/unit/evaluator/mocks.py
 create mode 100644 tests/unit/evaluator/test_dag.py
 create mode 100644 tests/unit/evaluator/test_nodes.py

diff --git a/src/modelplane/evaluator/context.py b/src/modelplane/evaluator/context.py
new file mode 100644
index 0000000..c30e661
--- /dev/null
+++ b/src/modelplane/evaluator/context.py
@@ -0,0 +1,17 @@
+from typing import Any
+
+
+class EvalContext:
+    """Context state passed around during DAG execution."""
+
+    def __init__(self, prompt_text: str, response: str) -> None:
+        self.prompt_text = prompt_text
+        self.response = response
+        self._parent_outputs = {}
+
+    def set_parent_outputs(self, outputs: dict[str, Any]) -> None:
+        self._parent_outputs = outputs
+
+    def parent_outputs(self) -> list[Any]:
+        """Return the outputs of all parent nodes that routed to the current node."""
+        return list(self._parent_outputs.values())
diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
new file mode 100644
index 0000000..fa46306
--- /dev/null
+++ b/src/modelplane/evaluator/dag.py
@@ -0,0 +1,287 @@
+"""DAGAnnotator and EvaluatorDAG implementation."""
+
+import collections
+from dataclasses import dataclass, field
+import functools
+from typing import Any, Optional
+
+import pandas as pd
+from modelgauge.annotator import Annotator
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.nodes import (
+    Arbiter,
+    EvaluatorDAGNode,
+    Gate,
+    Output,
+)
+from modelgauge.prompt import ChatPrompt, TextPrompt
+from modelgauge.prompt_formatting import format_chat
+from modelgauge.sut import SUTResponse
+from modelgauge.annotation import SafetyAnnotation
+
+
+def requires_validate_and_build(method):
+    @functools.wraps(method)
+    def wrapper(self, *args, **kwargs):
+        self._validate_and_build()
+        return method(self, *args, **kwargs)
+
+    return wrapper
+
+
+class EvaluatorDAG:
+    """DAG of EvaluatorNodes.
+
+    Usage:
+
+    refusal_gate = MyRefusalGate("RefusalGate", routes_true=[NONVIOLATING], routes_false=["NonRefusal"])
+    eval_non_refusal = MyNonRefusalEvaluator("NonRefusal", routes=["Arbiter"])
+    arbiter = MyArbiter("Arbiter", routes_true=[VIOLATING], routes_false=[NONVIOLATING])
+
+    dag = (
+        EvaluatorDAG("refusal_evaluator", outputs=[NONVIOLATING, VIOLATING])
+        .add_node(refusal_gate)
+        .add_node(eval_non_refusal)
+        .add_node(arbiter)
+    )
+    # run single
+    result = dag.run(EvalContext(prompt_text="...", response="..."))
+    # run batch
+    results_df = dag.run_dataframe(df)
+    """
+
+    def __init__(self, name: str, outputs: list[Output]) -> None:
+        self.name = name
+        self._nodes: dict[str, EvaluatorDAGNode] = {}
+        self._root_nodes: list[str] = []
+        self._ordered: list[str] = []
+        self._validated: bool = False
+        self._predecessors: dict[str, list[str]] = collections.defaultdict(list)
+        self._outputs = {output.name: output for output in outputs}
+
+    def add_node(
+        self,
+        node: EvaluatorDAGNode,
+    ) -> "EvaluatorDAG":
+        """Register a node with its routes."""
+
+        if node.name in self._all_names():
+            raise ValueError(
+                f"A different node named {node.name!r} is already registered."
+            )
+        self._nodes[node.name] = node
+        self._validated = False
+        return self
+
+    def _all_names(self) -> dict[str, EvaluatorDAGNode | Output]:
+        return {**self._nodes, **self._outputs}
+
+    def _validate_and_build(self) -> None:
+        """
+        Validate the DAG:
+        - All routes reference registered nodes.
+        - No cycles.
+        - All paths lead to an Output node.
+        - All Output nodes are declared as outputs in the DAG constructor.
+
+        Build:
+        - _predecessors: dict mapping node name to list of parent node names (for context during execution)
+        - _root_nodes: list of node names with no incoming routes (starting points)
+        - _ordered: list of node names in topological order (valid execution order)
+        """
+        # skip validation if we've already done it and the DAG hasn't changed
+        if self._validated:
+            return
+
+        all_named_entities = self._all_names()
+        # check that all route targets reference registered nodes
+        for node_name, node in self._nodes.items():
+            for target in node.all_routes():
+                if target not in all_named_entities:
+                    raise ValueError(
+                        f"Node {node_name!r} routes to unregistered node {target!r}."
+                    )
+
+        # check for cycles (Kahn's algorithm)
+        all_routes = {name: node.all_routes() for name, node in self._nodes.items()}
+        in_degree: dict[str, int] = {n: 0 for n in self._nodes}
+        for route in all_routes.values():
+            for t in route:
+                in_degree[t] += 1
+
+        root_nodes = [n for n in self._nodes if in_degree[n] == 0]
+        queue = collections.deque(root_nodes)
+        ordered: list[str] = []
+        while queue:
+            current = queue.popleft()
+            ordered.append(current)
+            for child in all_routes.get(current, []):
+                in_degree[child] -= 1
+                if in_degree[child] == 0:
+                    queue.append(child)
+
+        if len(ordered) != len(self._nodes):
+            # missing nodes
+            missing = set(self._nodes) - set(ordered)
+            raise ValueError(f"Graph contains a cycle. Missing nodes: {missing}")
+
+        # check all terminal nodes are Output nodes
+        terminal_nodes = [n for n in self._nodes if not all_routes.get(n)]
+        for terminal in terminal_nodes:
+            entity = all_named_entities[terminal]
+            if isinstance(entity, Output) and terminal not in self._outputs:
+                raise ValueError(
+                    f"Terminal Output node {terminal!r} is not declared as an output in the DAG constructor."
+                )
+            elif isinstance(entity, Arbiter):
+                if any(o.name not in self._outputs for o in entity.outputs()):
+                    raise ValueError(
+                        f"Terminal Arbiter node {terminal!r} has output(s) that are not declared as outputs in the DAG constructor."
+                    )
+            else:
+                raise ValueError(
+                    f"Terminal node {terminal!r} is not an Output or Arbiter node."
+                )
+
+        # get predecessors
+        for name, node in self._nodes.items():
+            for target in node.all_routes():
+                self._predecessors[target].append(name)
+
+        self._validated = True
+        self._root_nodes = root_nodes
+        self._ordered = ordered
+
+    @requires_validate_and_build
+    def run(
+        self,
+        ctx: EvalContext,
+    ) -> Output:
+        """
+        Execute the DAG on a single prompt/response.
+        """
+        active_nodes = self._root_nodes
+        outputs: dict[str, Any] = {}
+        while active_nodes:
+            next_active = []
+            for node_name in active_nodes:
+                print("Running node:", node_name)
+                # set parent outputs in context for this node
+                ctx.set_parent_outputs(
+                    {
+                        pred: outputs[pred]
+                        for pred in self._predecessors[node_name]
+                        if pred in outputs
+                    }
+                )
+                # run the node
+                node = self._nodes[node_name]
+                output = node.run(ctx)
+                if isinstance(output, Output):
+                    return output
+                outputs[node_name] = output
+                # see which nodes to activate next based on output and routing
+                next_active.extend(node.next_nodes())
+            active_nodes = next_active
+        raise ValueError("DAG execution completed without reaching an Output node.")
+
+    @requires_validate_and_build
+    def run_dataframe(
+        self,
+        df: pd.DataFrame,
+        prompt_text_col: str = "prompt_text",
+        response_col: str = "sut_response",
+    ) -> pd.DataFrame:
+        """Run the DAG over every row of a DataFrame."""
+
+        def _run_row(row: Any) -> Output:
+            ctx = EvalContext(
+                prompt_text=str(row[prompt_text_col]),
+                response=str(row[response_col]),
+            )
+            return self.run(ctx)
+
+        records = [_run_row(row) for _, row in df.iterrows()]
+
+        result_df = pd.DataFrame(records, index=df.index)
+        return pd.concat([df, result_df], axis=1)
+
+    @requires_validate_and_build
+    def total_cost(
+        self,
+        prompt_text: Optional[str],
+        response: Optional[str],
+    ) -> dict[str, float]:
+        """Run the DAG on all terminal paths and report total costs per path. If no prompt/response are provided, uses empty strings."""
+
+        ctx = EvalContext(
+            prompt_text=prompt_text or "",
+            response=response or "",
+        )
+
+        path_costs: dict[str, float] = {}
+
+        def _dfs(node_name: str, accumulated: float, path: list[str]) -> None:
+            # an Output route target terminates the path
+            if node_name in self._outputs:
+                path_costs[" -> ".join(path + [node_name])] = accumulated
+                return
+            node = self._nodes[node_name]
+            total = accumulated + node.cost(ctx)
+            targets = node.all_routes()
+            # terminal nodes (e.g. Arbiters) also terminate the path
+            if not targets:
+                path_costs[" -> ".join(path + [node_name])] = total
+                return
+            for target in targets:
+                _dfs(target, total, path + [node_name])
+
+        for root in self._root_nodes:
+            _dfs(root, 0.0, [])
+
+        return path_costs
+
+    @requires_validate_and_build
+    def visualize(self) -> None:
+        """Render the DAG structure with ascii."""
+        print(f"EvaluatorDAG: {self.name!r}")
+        print("=" * (len(self.name) + 18))
+        for node_name in self._ordered:
+            node = self._nodes[node_name]
+            node_type = type(node).__name__
+            if isinstance(node, Output):
+                route_str = f" → verdict='{node.name}'"
+            elif isinstance(node, Gate):
+                route_str = f" → True:{node.routes_true} False:{node.routes_false}"
+            else:
+                route_str = f" → {node.routes}"
+            print(f" [{node_type:10s}] {node_name}{route_str}")
+
+
+class DAGAnnotator(Annotator):
+    """Annotator that executes a DAG."""
+
+    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
+        super().__init__(uid)
+        self.dag = dag
+
+    def translate_prompt(
+        self,
+        prompt: TextPrompt | ChatPrompt,
+        response: SUTResponse,
+    ) -> EvalContext:
+        prompt_text = (
+            prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt)
+        )
+        return EvalContext(
+            prompt_text=prompt_text,
+            response=response.text,
+        )
+
+    def annotate(self, annotation_request: EvalContext) -> Output:
+        return self.dag.run(annotation_request)
+
+    def translate_response(
+        self,
+        request: EvalContext,
+        response: Output,
+    ) -> SafetyAnnotation:
+        """Map the DAG's Output verdict to a SafetyAnnotation (is_safe bool)."""
+        # TODO: unclear whether SafetyAnnotation is the right standardized output
+        return SafetyAnnotation(is_safe=response.is_safe())
diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
new file mode 100644
index 0000000..4725837
--- /dev/null
+++ b/src/modelplane/evaluator/nodes.py
@@ -0,0 +1,164 @@
+"""
+Node types for the EvaluatorDAG pipeline.
+
+Class hierarchy:
+
+    EvaluatorDAGNode (ABC)
+    ├── Gate (binary test; routes on True/False)
+    ├── Enricher (transforms context; routes unconditionally)
+    ├── Scorer (produces a float score; routes unconditionally)
+    └── Arbiter (produces output)
+    Output (terminal node; carries a verdict value)
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.outputs import Output
+
+
+class EvaluatorDAGNode(ABC):
+    def __init__(
+        self,
+        name: str,
+        routes_true: Optional[list[str | Output]] = None,
+        routes_false: Optional[list[str | Output]] = None,
+        routes: Optional[list[str | Output]] = None,
+    ) -> None:
+        self.name = name
+        self.routes_true = routes_true or []
+        self.routes_false = routes_false or []
+        self.routes = routes or []
+        self._was_run = False
+        self._output: Any = None
+        self.validate()
+
+    @abstractmethod
+    def _run(self, ctx: EvalContext) -> Any:
+        pass
+
+    def run(self, ctx: EvalContext) -> Any:
+        """Execute the node and return its output."""
+        if self._was_run:
+            return self._output
+        self._output = self._run(ctx)
+        self._was_run = True
+        return self._output
+
+    def cost(self, ctx: EvalContext) -> float:
+        """Return the estimated cost of running this node. Default is 0.0;
+        override for LLM calls or other expensive operations."""
+        return 0.0
+
+    @property
+    def output(self) -> Any:
+        """Return the output of this node after it has been run."""
+        if not self._was_run:
+            raise ValueError(f"Node {self.name!r} has not been run yet.")
+        return self._output
+
+    def __repr__(self) -> str:
+        return f"{self.name!r}: ({self.__class__.__name__})"
+
+    def all_routes(self) -> list[str]:
+        """Return a list of all route targets from this node."""
+        return [
+            *[r if isinstance(r, str) else r.name for r in self.routes_true],
+            *[r if isinstance(r, str) else r.name for r in self.routes_false],
+            *[r if isinstance(r, str) else r.name for r in self.routes],
+        ]
+
+    def next_nodes(self) -> list[str | Output]:
+        """Given a node output value, return the list of next node names to activate."""
+        if not self._was_run:
+            raise ValueError("Cannot get next nodes before running the node.")
+        if isinstance(self, Gate):
+            return self.routes_true if self.output else self.routes_false
+        else:
+            return self.routes
+
+    def validate(self) -> None:
+        """Validate that the node's routing configuration is consistent with its type."""
+        # validate that routes with Outputs only have one Output
+        for route_list in [self.routes_true, self.routes_false, self.routes]:
+            output_routes = [r for r in route_list if isinstance(r, Output)]
+            if len(output_routes) > 1:
+                raise ValueError(
+                    f"{self!r} has multiple Output routes {output_routes}, which is not allowed."
+                )
+
+
+def _validate_binary_routes(node: EvaluatorDAGNode) -> None:
+    if not node.routes_true or not node.routes_false:
+        raise ValueError(f"{node!r} requires both routes_true and routes_false")
+    if node.routes:
+        raise ValueError(
+            f"{node!r} should not have routes= (use routes_true= / routes_false=)"
+        )
+
+
+def _validate_unary_routes(node: EvaluatorDAGNode) -> None:
+    if not node.routes:
+        raise ValueError(f"{node!r} requires routes=")
+    if node.routes_true or node.routes_false:
+        raise ValueError(
+            f"{node!r} should not have routes_true= / routes_false= (use routes=)"
+        )
+
+
+def _validate_terminal(node: EvaluatorDAGNode) -> None:
+    if node.routes_true or node.routes_false or node.routes:
+        raise ValueError(f"{node!r} is terminal and cannot have routing kwargs")
+
+
+class Gate(EvaluatorDAGNode):
+    """Binary test node."""
+
+    @abstractmethod
+    def _run(self, ctx: EvalContext) -> bool:
+        """Return True or False to indicate which route to take from this gate."""
+
+    def validate(self) -> None:
+        super().validate()
+        _validate_binary_routes(self)
+
+
+class Enricher(EvaluatorDAGNode):
+    """Context transformation node."""
+
+    @abstractmethod
+    def _run(self, ctx: EvalContext) -> str:
+        """Return a new string representing the enriched context."""
+
+    def validate(self) -> None:
+        super().validate()
+        _validate_unary_routes(self)
+
+
+class Scorer(EvaluatorDAGNode):
+    """Scoring node. Produces a float score from the (possibly enriched) context."""
+
+    @abstractmethod
+    def _run(self, ctx: EvalContext) -> float:
+        """Return a score for the current context."""
+
+    def validate(self) -> None:
+        super().validate()
+        _validate_unary_routes(self)
+
+
+class Arbiter(EvaluatorDAGNode):
+    """Takes context and returns an Output indicating the final verdict (based on routes)."""
+
+    @abstractmethod
+    def _run(self, ctx: EvalContext) -> Output:
+        """Return an Output indicating the final verdict."""
+
+    def validate(self) -> None:
+        super().validate()
+        _validate_terminal(self)
+
+    @abstractmethod
+    def outputs(self) -> list[Output]:
+        """Return the list of possible Output verdicts this Arbiter can return."""
diff --git a/src/modelplane/evaluator/outputs.py b/src/modelplane/evaluator/outputs.py
new file mode 100644
index 0000000..ad496d6
--- /dev/null
+++ b/src/modelplane/evaluator/outputs.py
@@ -0,0 +1,30 @@
+from abc import abstractmethod
+
+
+class Output:
+    @abstractmethod
+    def is_safe(self) -> bool:
+        pass
+
+    @property
+    def name(self) -> str:
+        return self.__class__.__name__
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__
+
+
+class Violating(Output):
+
+    def is_safe(self) -> bool:
+        return False
+
+
+class NonViolating(Output):
+
+    def is_safe(self) -> bool:
+        return True
+
+
+VIOLATING = Violating()
+NONVIOLATING = NonViolating()
diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
new file mode 100644
index 0000000..2016c4d
--- /dev/null
+++ b/tests/unit/evaluator/conftest.py
@@ -0,0 +1,105 @@
+"""Shared mock node implementations and helpers for evaluator tests."""
+
+import pytest
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.dag import EvaluatorDAG
+from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output
+from .mocks import (
+    AlwaysFalse,
+    AlwaysNonViolating,
+    AlwaysTrue,
+    AlwaysViolating,
+    FixedScorer,
+    LLMEnricher,
+    LLMEnricher,
+    LowerCaseScorer,
+    LowerCaser,
+    PromptLengthGate,
+    ThresholdArbiter,
+    UpperCaseScorer,
+    UpperCaser,
+)
+
+TRUE_BRANCH: list[str | Output] = ["true_branch"]
+FALSE_BRANCH: list[str | Output] = ["false_branch"]
+DEFAULT_BRANCH: list[str | Output] = ["next_node"]
+SCORE1 = 1.0
+SCORE2 = 2.0
+
+
+@pytest.fixture
+def always_true_gate() -> AlwaysTrue:
+    return AlwaysTrue(
+        name="always_true", routes_true=TRUE_BRANCH, routes_false=FALSE_BRANCH
+    )
+
+
+@pytest.fixture
+def always_false_gate() -> AlwaysFalse:
+    return AlwaysFalse(
+        name="always_false", routes_true=TRUE_BRANCH, routes_false=FALSE_BRANCH
+    )
+
+
+@pytest.fixture
+def lower_caser() -> LowerCaser:
+    return LowerCaser(name="lower_caser", routes=DEFAULT_BRANCH)
+
+
+@pytest.fixture
+def score_1() -> FixedScorer:
+    return FixedScorer(name="score_1", value=SCORE1, routes=DEFAULT_BRANCH)
+
+
+@pytest.fixture
+def score_2() -> FixedScorer:
+    return FixedScorer(name="score_2", value=SCORE2, routes=DEFAULT_BRANCH)
+
+
+@pytest.fixture
+def costly_enricher() -> LLMEnricher:
+    return LLMEnricher(name="costly_enricher", routes=DEFAULT_BRANCH)
+
+
+@pytest.fixture
+def sample_ctx() -> EvalContext:
+    return EvalContext(prompt_text="Hello, world!", response="This is a response.")
+
+
+@pytest.fixture
+def always_violating() -> AlwaysViolating:
+    return AlwaysViolating(name="always_violating")
+
+
+@pytest.fixture
+def always_non_violating() -> AlwaysNonViolating:
+    return AlwaysNonViolating(name="always_non_violating")
+
+
+@pytest.fixture
+def threshold_arbiter() -> ThresholdArbiter:
+    return ThresholdArbiter(name="threshold_arbiter", threshold=1.5)
+
+
+@pytest.fixture
+def simple_dag():
+    return (
+        EvaluatorDAG("simple", outputs=[NONVIOLATING, VIOLATING])
+        .add_node(
+            PromptLengthGate(
+                name="prompt_parity",
+                routes_true=["lower_caser"],
+                routes_false=["upper_caser"],
+            )
+        )
+        .add_node(
+            LowerCaser(name="lower_caser", routes=["lower_scorer", "upper_scorer"])
+        )
+        .add_node(
+            UpperCaser(name="upper_caser", routes=["lower_scorer", "upper_scorer"])
+        )
+        .add_node(LowerCaseScorer(name="lower_scorer", routes=["threshold_arbiter"]))
+        .add_node(UpperCaseScorer(name="upper_scorer", routes=["threshold_arbiter"]))
+        .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5))
+    )
diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py
new file mode 100644
index 0000000..84bf39a
--- /dev/null
+++ b/tests/unit/evaluator/mocks.py
@@ -0,0 +1,109 @@
+import math
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.nodes import Arbiter, Enricher, Gate, Scorer
+from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output
+
+
+class PassthroughGate(Gate):
+    ROUTE_TO_TAKE: bool
+
+    def _run(self, ctx: EvalContext) -> bool:
+        return self.ROUTE_TO_TAKE
+
+
+class AlwaysTrue(PassthroughGate):
+    ROUTE_TO_TAKE = True
+
+
+class AlwaysFalse(PassthroughGate):
+    ROUTE_TO_TAKE = False
+
+
+class PromptLengthGate(Gate):
+    def _run(self, ctx: EvalContext) -> bool:
+        return len(ctx.prompt_text) % 2 == 0
+
+
+class LowerCaser(Enricher):
+    """Enriches by returning the response lowercased."""
+
+    def _run(self, ctx: EvalContext) -> str:
+        return ctx.response.lower()
+
+
+class UpperCaser(Enricher):
+    """Enriches by returning the response uppercased."""
+
+    def _run(self, ctx: EvalContext) -> str:
+        return ctx.response.upper()
+
+
+class LLMEnricher(Enricher):
+
+    def cost(self, ctx: EvalContext) -> float:
+        return len(ctx.prompt_text) + len(ctx.response)
+
+    def _run(self, ctx: EvalContext) -> str:
+        return ctx.response
+
+
+class FixedScorer(Scorer):
+    """Returns a fixed float score regardless of context."""
+
+    def __init__(self, name: str, value: float, **kwargs):
+        super().__init__(name, **kwargs)
+        self.value = value
+
+    def _run(self, ctx: EvalContext) -> float:
+        return self.value
+
+
+class LowerCaseScorer(Scorer):
+    """Scores based on the percentage of lowercase characters in the response."""
+
+    def _run(self, ctx: EvalContext) -> float:
+        if not ctx.response:
+            return 0.0
+        num_lower = sum(1 for c in ctx.response if c.islower())
+        return num_lower / len(ctx.response)
+
+
+class UpperCaseScorer(Scorer):
+    """Scores based on the percentage of uppercase characters in the response."""
+
+    def _run(self, ctx: EvalContext) -> float:
+        if not ctx.response:
+            return 0.0
+        num_upper = sum(1 for c in ctx.response if c.isupper())
+        return num_upper / len(ctx.response)
+
+
+class AlwaysViolating(Arbiter):
+    def _run(self, ctx: EvalContext) -> Output:
+        return VIOLATING
+
+    def outputs(self) -> list[Output]:
+        return [VIOLATING]
+
+
+class AlwaysNonViolating(Arbiter):
+    def _run(self, ctx: EvalContext) -> Output:
+        return NONVIOLATING
+
+    def outputs(self) -> list[Output]:
+        return [NONVIOLATING]
+
+
+class ThresholdArbiter(Arbiter):
+    def __init__(self, name: str, threshold: float, **kwargs):
+        super().__init__(name, **kwargs)
+        self.threshold = threshold
+
+    def _run(self, ctx: EvalContext) -> Output:
+        scores = ctx.parent_outputs()
+        score = sum(scores) / len(scores)
+        return VIOLATING if score >= self.threshold else NONVIOLATING
+
+    def outputs(self) -> list[Output]:
+        return [VIOLATING, NONVIOLATING]
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
new file mode 100644
index 0000000..f12f1bf
--- /dev/null
+++ b/tests/unit/evaluator/test_dag.py
@@ -0,0 +1,9 @@
+"""Unit tests for EvaluatorDAG construction, validation, execution, and visualization."""
+
+
+import pandas as pd
+
+
+def test_dag_run(simple_dag, sample_ctx):
+    result = simple_dag.run(sample_ctx)
+    assert result.is_safe()
diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py
new file mode 100644
index 0000000..e766223
--- /dev/null
+++ b/tests/unit/evaluator/test_nodes.py
@@ -0,0 +1,65 @@
+"""Unit tests for individual EvaluatorDAGNode subclasses."""
+
+import pytest
+
+from .conftest import DEFAULT_BRANCH, FALSE_BRANCH, SCORE1, SCORE2, TRUE_BRANCH
+
+
+def test_error_getting_next_nodes_before_run(sample_ctx, lower_caser):
+    with pytest.raises(
+        ValueError, match="Cannot get next nodes before running the node."
+    ):
+        lower_caser.next_nodes()
+
+
+def test_true_routes_to_true_branch(sample_ctx, always_true_gate):
+    output = always_true_gate.run(sample_ctx)
+    assert output
+    assert always_true_gate.next_nodes() == TRUE_BRANCH
+
+
+def test_false_routes_to_false_branch(sample_ctx, always_false_gate):
+    output = always_false_gate.run(sample_ctx)
+    assert not output
+    assert always_false_gate.next_nodes() == FALSE_BRANCH
+
+
+def test_output_cached(sample_ctx, lower_caser):
+    output1 = lower_caser.run(sample_ctx)
+    assert lower_caser._was_run
+    assert lower_caser._output == output1
+
+
+def test_lower_caser(sample_ctx, lower_caser):
+    output = lower_caser.run(sample_ctx)
+    assert output == sample_ctx.response.lower()
+    assert lower_caser.next_nodes() == DEFAULT_BRANCH
+
+
+def test_fixed_scorer(sample_ctx, score_1):
+    output = score_1.run(sample_ctx)
+    assert output == SCORE1
+    assert score_1.next_nodes() == DEFAULT_BRANCH
+
+
+def test_consistent_arbiters(
+    sample_ctx, score_1, score_2, always_violating, always_non_violating
+):
+    parent_outputs = {score_1.name: SCORE1, score_2.name: SCORE2}
+    sample_ctx.set_parent_outputs(parent_outputs)
+    output = always_violating.run(sample_ctx)
+    assert not output.is_safe()
+    output = always_non_violating.run(sample_ctx)
+    assert output.is_safe()
+
+
+def test_threshold_arbiter_true(sample_ctx, threshold_arbiter):
+    sample_ctx.set_parent_outputs({"parent0": SCORE2, "parent1": SCORE2})
+    output = threshold_arbiter.run(sample_ctx)
+    assert not output.is_safe()
+
+
+def test_threshold_arbiter_false(sample_ctx, threshold_arbiter):
+    sample_ctx.set_parent_outputs({"parent0": SCORE1, "parent1": SCORE1})
+    output = threshold_arbiter.run(sample_ctx)
+    assert output.is_safe()

From f8534d906489651d63eb66868c3bde093f7ff595 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Fri, 3 Apr 2026 10:46:00 -0400
Subject: [PATCH 02/19] Fixes.
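
Nodes are now stateless: the output cache (_was_run/_output) is gone, run()
replaces _run(), and next_nodes() takes the output value explicitly, so one
node instance can be reused across rows and runs. A minimal sketch of the new
contract (EvenLengthGate is a hypothetical node for illustration, not part of
this change):

    from modelplane.evaluator.context import EvalContext
    from modelplane.evaluator.nodes import Gate

    class EvenLengthGate(Gate):
        """Routes True when the response length is even."""

        def run(self, ctx: EvalContext) -> bool:
            return len(ctx.response) % 2 == 0

    gate = EvenLengthGate("even", routes_true=["a"], routes_false=["b"])
    out = gate.run(EvalContext(prompt_text="p", response="hello!"))
    assert out is True                    # len("hello!") == 6
    assert gate.next_nodes(out) == ["a"]  # caller passes the output back in
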
---
 src/modelplane/evaluator/dag.py    |  9 +++++---
 src/modelplane/evaluator/nodes.py  | 34 +++++++-----------------------
 tests/unit/evaluator/__init__.py   |  0
 tests/unit/evaluator/mocks.py      | 24 ++++++++++-----------
 tests/unit/evaluator/test_dag.py   | 21 +++++++++++++++++-
 tests/unit/evaluator/test_nodes.py | 23 ++++----------------
 6 files changed, 49 insertions(+), 62 deletions(-)
 create mode 100644 tests/unit/evaluator/__init__.py

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index fa46306..b0f0542 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -51,6 +51,8 @@ class EvaluatorDAG:
     results_df = dag.run_dataframe(df)
     """
 
+    DATAFRAME_OUTPUT_COL = "output"
+
     def __init__(self, name: str, outputs: list[Output]) -> None:
         self.name = name
         self._nodes: dict[str, EvaluatorDAGNode] = {}
@@ -166,7 +168,6 @@ def run(
         while active_nodes:
             next_active = []
             for node_name in active_nodes:
-                print("Running node:", node_name)
                 # set parent outputs in context for this node
                 ctx.set_parent_outputs(
                     {
@@ -182,7 +183,7 @@ def run(
                     return output
                 outputs[node_name] = output
                 # see which nodes to activate next based on output and routing
-                next_active.extend(node.next_nodes())
+                next_active.extend(node.next_nodes(output))
             active_nodes = next_active
         raise ValueError("DAG execution completed without reaching an Output node.")
 
@@ -204,7 +205,9 @@ def _run_row(row: Any) -> Output:
 
         records = [_run_row(row) for _, row in df.iterrows()]
 
-        result_df = pd.DataFrame(records, index=df.index)
+        result_df = pd.DataFrame(
+            {self.DATAFRAME_OUTPUT_COL: [r.name for r in records]}, index=df.index
+        )
         return pd.concat([df, result_df], axis=1)
 
     @requires_validate_and_build
diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
index 4725837..0e72d33 100644
--- a/src/modelplane/evaluator/nodes.py
+++ b/src/modelplane/evaluator/nodes.py
@@ -30,34 +30,18 @@ def __init__(
         self.routes_true = routes_true or []
         self.routes_false = routes_false or []
         self.routes = routes or []
-        self._was_run = False
-        self._output: Any = None
         self.validate()
 
     @abstractmethod
-    def _run(self, ctx: EvalContext) -> Any:
-        pass
-
     def run(self, ctx: EvalContext) -> Any:
         """Execute the node and return its output."""
-        if self._was_run:
-            return self._output
-        self._output = self._run(ctx)
-        self._was_run = True
-        return self._output
+        pass
 
     def cost(self, ctx: EvalContext) -> float:
         """Return the estimated cost of running this node. Default is 0.0;
         override for LLM calls or other expensive operations."""
         return 0.0
 
-    @property
-    def output(self) -> Any:
-        """Return the output of this node after it has been run."""
-        if not self._was_run:
-            raise ValueError(f"Node {self.name!r} has not been run yet.")
-        return self._output
-
     def __repr__(self) -> str:
         return f"{self.name!r}: ({self.__class__.__name__})"
 
@@ -69,12 +53,10 @@ def all_routes(self) -> list[str]:
             *[r if isinstance(r, str) else r.name for r in self.routes],
         ]
 
-    def next_nodes(self) -> list[str | Output]:
-        """Given a node output value, return the list of next node names to activate."""
-        if not self._was_run:
-            raise ValueError("Cannot get next nodes before running the node.")
+    def next_nodes(self, output: Any) -> list[str | Output]:
+        """Given the node's output value, return the list of next node names to activate."""
         if isinstance(self, Gate):
-            return self.routes_true if self.output else self.routes_false
+            return self.routes_true if output else self.routes_false
         else:
             return self.routes
 
@@ -116,7 +98,7 @@ class Gate(EvaluatorDAGNode):
     """Binary test node."""
 
     @abstractmethod
-    def _run(self, ctx: EvalContext) -> bool:
+    def run(self, ctx: EvalContext) -> bool:
         """Return True or False to indicate which route to take from this gate."""
 
     def validate(self) -> None:
@@ -128,7 +110,7 @@ class Enricher(EvaluatorDAGNode):
     """Context transformation node."""
 
     @abstractmethod
-    def _run(self, ctx: EvalContext) -> str:
+    def run(self, ctx: EvalContext) -> str:
         """Return a new string representing the enriched context."""
 
     def validate(self) -> None:
@@ -140,7 +122,7 @@ class Scorer(EvaluatorDAGNode):
     """Scoring node. Produces a float score from the (possibly enriched) context."""
 
     @abstractmethod
-    def _run(self, ctx: EvalContext) -> float:
+    def run(self, ctx: EvalContext) -> float:
         """Return a score for the current context."""
 
     def validate(self) -> None:
@@ -152,7 +134,7 @@ class Arbiter(EvaluatorDAGNode):
     """Takes context and returns an Output indicating the final verdict (based on routes)."""
 
     @abstractmethod
-    def _run(self, ctx: EvalContext) -> Output:
+    def run(self, ctx: EvalContext) -> Output:
         """Return an Output indicating the final verdict."""
 
     def validate(self) -> None:
diff --git a/tests/unit/evaluator/__init__.py b/tests/unit/evaluator/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py
index 84bf39a..398e44e 100644
--- a/tests/unit/evaluator/mocks.py
+++ b/tests/unit/evaluator/mocks.py
@@ -1,5 +1,3 @@
-import math
-
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, Enricher, Gate, Scorer
 from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output
@@ -8,7 +6,7 @@ class PassthroughGate(Gate):
     ROUTE_TO_TAKE: bool
 
-    def _run(self, ctx: EvalContext) -> bool:
+    def run(self, ctx: EvalContext) -> bool:
         return self.ROUTE_TO_TAKE
 
 
@@ -21,21 +19,21 @@ class AlwaysFalse(PassthroughGate):
 
 
 class PromptLengthGate(Gate):
-    def _run(self, ctx: EvalContext) -> bool:
+    def run(self, ctx: EvalContext) -> bool:
         return len(ctx.prompt_text) % 2 == 0
 
 
 class LowerCaser(Enricher):
     """Enriches by returning the response lowercased."""
 
-    def _run(self, ctx: EvalContext) -> str:
+    def run(self, ctx: EvalContext) -> str:
         return ctx.response.lower()
 
 
 class UpperCaser(Enricher):
     """Enriches by returning the response uppercased."""
 
-    def _run(self, ctx: EvalContext) -> str:
+    def run(self, ctx: EvalContext) -> str:
         return ctx.response.upper()
 
 
@@ -44,7 +42,7 @@ class LLMEnricher(Enricher):
     def cost(self, ctx: EvalContext) -> float:
         return len(ctx.prompt_text) + len(ctx.response)
 
-    def _run(self, ctx: EvalContext) -> str:
+    def run(self, ctx: EvalContext) -> str:
         return ctx.response
 
 
@@ -55,14 +53,14 @@ def __init__(self, name: str, value: float, **kwargs):
         super().__init__(name, **kwargs)
         self.value = value
 
-    def _run(self, ctx: EvalContext) -> float:
+    def run(self, ctx: EvalContext) -> float:
         return self.value
 
 
 class LowerCaseScorer(Scorer):
     """Scores based on the percentage of lowercase characters in the response."""
 
-    def _run(self, ctx: EvalContext) -> float:
+    def run(self, ctx: EvalContext) -> float:
         if not ctx.response:
             return 0.0
         num_lower = sum(1 for c in ctx.response if c.islower())
@@ -72,7 +70,7 @@ def _run(self, ctx: EvalContext) -> float:
 class UpperCaseScorer(Scorer):
     """Scores based on the percentage of uppercase characters in the response."""
 
-    def _run(self, ctx: EvalContext) -> float:
+    def run(self, ctx: EvalContext) -> float:
         if not ctx.response:
             return 0.0
         num_upper = sum(1 for c in ctx.response if c.isupper())
@@ -80,7 +78,7 @@ def _run(self, ctx: EvalContext) -> float:
 
 
 class AlwaysViolating(Arbiter):
-    def _run(self, ctx: EvalContext) -> Output:
+    def run(self, ctx: EvalContext) -> Output:
         return VIOLATING
 
     def outputs(self) -> list[Output]:
@@ -88,7 +86,7 @@ def outputs(self) -> list[Output]:
 
 
 class AlwaysNonViolating(Arbiter):
-    def _run(self, ctx: EvalContext) -> Output:
+    def run(self, ctx: EvalContext) -> Output:
         return NONVIOLATING
 
     def outputs(self) -> list[Output]:
@@ -100,7 +98,7 @@ def __init__(self, name: str, threshold: float, **kwargs):
         super().__init__(name, **kwargs)
         self.threshold = threshold
 
-    def _run(self, ctx: EvalContext) -> Output:
+    def run(self, ctx: EvalContext) -> Output:
         scores = ctx.parent_outputs()
         score = sum(scores) / len(scores)
         return VIOLATING if score >= self.threshold else NONVIOLATING
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index f12f1bf..07cc87c 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -1,9 +1,28 @@
 """Unit tests for EvaluatorDAG construction, validation, execution, and visualization."""
 
-
 import pandas as pd
 
 
 def test_dag_run(simple_dag, sample_ctx):
     result = simple_dag.run(sample_ctx)
     assert result.is_safe()
+
+
+def test_dag_run_with_dataframe(simple_dag):
+    # "hello world" (space lowers avg below threshold) → safe
+    # "helloworld" (no space, avg = 0.5 = threshold) → unsafe
+    # Alternate even/odd prompt lengths to exercise both enricher paths.
+    df = pd.DataFrame(
+        {
+            "prompt_text": ["a", "ab", "abc", "abcd"],  # odd, even, odd, even
+            "sut_response": ["hello world", "helloworld", "hello world", "helloworld"],
+        }
+    )
+    result_df = simple_dag.run_dataframe(df)
+
+    assert len(result_df) == len(df)
+    assert "prompt_text" in result_df.columns
+    assert "sut_response" in result_df.columns
+    verdicts = result_df[simple_dag.DATAFRAME_OUTPUT_COL].tolist()
+    expected_verdicts = ["NonViolating", "Violating", "NonViolating", "Violating"]
+    assert verdicts == expected_verdicts
diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py
index e766223..e046d19 100644
--- a/tests/unit/evaluator/test_nodes.py
+++ b/tests/unit/evaluator/test_nodes.py
@@ -1,45 +1,30 @@
 """Unit tests for individual EvaluatorDAGNode subclasses."""
 
-import pytest
-
 from .conftest import DEFAULT_BRANCH, FALSE_BRANCH, SCORE1, SCORE2, TRUE_BRANCH
 
 
-def test_error_getting_next_nodes_before_run(sample_ctx, lower_caser):
-    with pytest.raises(
-        ValueError, match="Cannot get next nodes before running the node."
-    ):
-        lower_caser.next_nodes()
-
-
 def test_true_routes_to_true_branch(sample_ctx, always_true_gate):
     output = always_true_gate.run(sample_ctx)
     assert output
-    assert always_true_gate.next_nodes() == TRUE_BRANCH
+    assert always_true_gate.next_nodes(output) == TRUE_BRANCH
 
 
 def test_false_routes_to_false_branch(sample_ctx, always_false_gate):
     output = always_false_gate.run(sample_ctx)
     assert not output
-    assert always_false_gate.next_nodes() == FALSE_BRANCH
-
-
-def test_output_cached(sample_ctx, lower_caser):
-    output1 = lower_caser.run(sample_ctx)
-    assert lower_caser._was_run
-    assert lower_caser._output == output1
+    assert always_false_gate.next_nodes(output) == FALSE_BRANCH
 
 
 def test_lower_caser(sample_ctx, lower_caser):
     output = lower_caser.run(sample_ctx)
     assert output == sample_ctx.response.lower()
-    assert lower_caser.next_nodes() == DEFAULT_BRANCH
+    assert lower_caser.next_nodes(output) == DEFAULT_BRANCH
 
 
 def test_fixed_scorer(sample_ctx, score_1):
     output = score_1.run(sample_ctx)
     assert output == SCORE1
-    assert score_1.next_nodes() == DEFAULT_BRANCH
+    assert score_1.next_nodes(output) == DEFAULT_BRANCH
 
 
 def test_consistent_arbiters(

From a966cabb4677eac380c7e129024632d498f306bc Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Fri, 3 Apr 2026 14:14:06 -0400
Subject: [PATCH 03/19] Refactor EvalContext to standardize prompt attribute
 naming

---
 src/modelplane/evaluator/context.py |  4 ++--
 src/modelplane/evaluator/dag.py     | 16 ++++++++--------
 tests/unit/evaluator/conftest.py    |  2 +-
 tests/unit/evaluator/mocks.py       |  4 ++--
 tests/unit/evaluator/test_dag.py    |  8 ++++----
 5 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/modelplane/evaluator/context.py b/src/modelplane/evaluator/context.py
index c30e661..7df9faa 100644
--- a/src/modelplane/evaluator/context.py
+++ b/src/modelplane/evaluator/context.py
@@ -4,8 +4,8 @@ class EvalContext:
     """Context state passed around during DAG execution."""
 
-    def __init__(self, prompt_text: str, response: str) -> None:
-        self.prompt_text = prompt_text
+    def __init__(self, prompt: str, response: str) -> None:
+        self.prompt = prompt
         self.response = response
         self._parent_outputs = {}
 
diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index b0f0542..f6a8485 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -46,7 +46,7 @@ class EvaluatorDAG:
         .add_node(arbiter)
     )
     # run single
-    result = dag.run(EvalContext(prompt_text="...", response="..."))
+    result = dag.run(EvalContext(prompt="...", response="..."))
     # run batch
     results_df = dag.run_dataframe(df)
     """
@@ -191,14 +191,14 @@ def run(
     def run_dataframe(
         self,
         df: pd.DataFrame,
-        prompt_text_col: str = "prompt_text",
-        response_col: str = "sut_response",
+        prompt_col: str = "prompt",
+        response_col: str = "response",
     ) -> pd.DataFrame:
         """Run the DAG over every row of a DataFrame."""
 
         def _run_row(row: Any) -> Output:
             ctx = EvalContext(
-                prompt_text=str(row[prompt_text_col]),
+                prompt=str(row[prompt_col]),
                 response=str(row[response_col]),
             )
             return self.run(ctx)
@@ -213,14 +213,14 @@ def _run_row(row: Any) -> Output:
     @requires_validate_and_build
     def total_cost(
         self,
-        prompt_text: Optional[str],
+        prompt: Optional[str],
         response: Optional[str],
     ) -> dict[str, float]:
         """Run the DAG on all terminal paths and report total costs per path. If no prompt/response are provided, uses empty strings."""
 
         ctx = EvalContext(
-            prompt_text=prompt_text or "",
+            prompt=prompt or "",
             response=response or "",
         )
@@ -269,11 +269,11 @@ def translate_prompt(
         prompt: TextPrompt | ChatPrompt,
         response: SUTResponse,
     ) -> EvalContext:
-        prompt_text = (
+        prompt_str = (
             prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt)
         )
         return EvalContext(
-            prompt_text=prompt_text,
+            prompt=prompt_str,
             response=response.text,
         )
diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
index 2016c4d..8fa44c8 100644
--- a/tests/unit/evaluator/conftest.py
+++ b/tests/unit/evaluator/conftest.py
@@ -64,7 +64,7 @@ def costly_enricher() -> LLMEnricher:
 
 @pytest.fixture
 def sample_ctx() -> EvalContext:
-    return EvalContext(prompt_text="Hello, world!", response="This is a response.")
+    return EvalContext(prompt="Hello, world!", response="This is a response.")
diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py
index 398e44e..f405486 100644
--- a/tests/unit/evaluator/mocks.py
+++ b/tests/unit/evaluator/mocks.py
@@ -20,7 +20,7 @@ class AlwaysFalse(PassthroughGate):
 
 class PromptLengthGate(Gate):
     def run(self, ctx: EvalContext) -> bool:
-        return len(ctx.prompt_text) % 2 == 0
+        return len(ctx.prompt) % 2 == 0
 
 
 class LowerCaser(Enricher):
@@ -40,7 +40,7 @@ def run(self, ctx: EvalContext) -> str:
 class LLMEnricher(Enricher):
 
     def cost(self, ctx: EvalContext) -> float:
-        return len(ctx.prompt_text) + len(ctx.response)
+        return len(ctx.prompt) + len(ctx.response)
 
     def run(self, ctx: EvalContext) -> str:
         return ctx.response
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index 07cc87c..3532045 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -14,15 +14,15 @@ def test_dag_run_with_dataframe(simple_dag):
     # Alternate even/odd prompt lengths to exercise both enricher paths.
     df = pd.DataFrame(
         {
-            "prompt_text": ["a", "ab", "abc", "abcd"],  # odd, even, odd, even
-            "sut_response": ["hello world", "helloworld", "hello world", "helloworld"],
+            "prompt": ["a", "ab", "abc", "abcd"],  # odd, even, odd, even
+            "response": ["hello world", "helloworld", "hello world", "helloworld"],
         }
     )
     result_df = simple_dag.run_dataframe(df)
 
     assert len(result_df) == len(df)
-    assert "prompt_text" in result_df.columns
-    assert "sut_response" in result_df.columns
+    assert "prompt" in result_df.columns
+    assert "response" in result_df.columns
     verdicts = result_df[simple_dag.DATAFRAME_OUTPUT_COL].tolist()
     expected_verdicts = ["NonViolating", "Violating", "NonViolating", "Violating"]
     assert verdicts == expected_verdicts

From dcd1d0ea491706c8537dd87178ea8367fd6b7c52 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Fri, 3 Apr 2026 14:40:26 -0400
Subject: [PATCH 04/19] Nice visualization plus other fixes.

---
 Dockerfile.jupyter              |   3 +-
 pyproject.toml                  |   1 +
 src/modelplane/evaluator/dag.py | 249 ++++++++++++++++++++++++++------
 uv.lock                         |  11 ++
 4 files changed, 222 insertions(+), 42 deletions(-)

diff --git a/Dockerfile.jupyter b/Dockerfile.jupyter
index d5138c6..271bfdc 100644
--- a/Dockerfile.jupyter
+++ b/Dockerfile.jupyter
@@ -1,13 +1,14 @@
 FROM python:3.12-slim
 
 ENV PATH="/root/.local/bin:$PATH"
+ENV PYTHONPATH="/app/flightpaths/flights"
 
 # Used for the notebook server
 WORKDIR /app
 
 # pipx needed for uv installation script
 # ssh client needed for installing private modelbench dependencies
 # git needed for dvc
-RUN apt-get update && apt-get install -y pipx openssh-client git && \
+RUN apt-get update && apt-get install -y pipx openssh-client git graphviz && \
     pipx install uv
 
 COPY pyproject.toml uv.lock README.md ./
diff --git a/pyproject.toml b/pyproject.toml
index 305a024..e9b4f4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
     "scikit-learn>=1.5.0,<2.0.0",
     "pandas>=2.2.2,<4",
     "modelbench @ git+https://github.com/mlcommons/modelbench.git",
+    "graphviz>=0.20,<1",
 ]
 
 [project.scripts]
diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index f6a8485..a7475e1 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -1,24 +1,21 @@
 """DAGAnnotator and EvaluatorDAG implementation."""
 
 import collections
-from dataclasses import dataclass, field
 import functools
+import os
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
 from typing import Any, Optional
 
 import pandas as pd
+from modelgauge.annotation import SafetyAnnotation
 from modelgauge.annotator import Annotator
-
-from modelplane.evaluator.context import EvalContext
-from modelplane.evaluator.nodes import (
-    Arbiter,
-    EvaluatorDAGNode,
-    Gate,
-    Output,
-)
 from modelgauge.prompt import ChatPrompt, TextPrompt
 from modelgauge.prompt_formatting import format_chat
 from modelgauge.sut import SUTResponse
-from modelgauge.annotation import SafetyAnnotation
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.nodes import Arbiter, EvaluatorDAGNode, Gate, Output
 
 
 def requires_validate_and_build(method):
@@ -110,6 +107,8 @@ def _validate_and_build(self) -> None:
         in_degree: dict[str, int] = {n: 0 for n in self._nodes}
         for route in all_routes.values():
             for t in route:
+                if t in self._outputs:
+                    continue
                 in_degree[t] += 1
 
         root_nodes = [n for n in self._nodes if in_degree[n] == 0]
@@ -119,6 +118,8 @@ def _validate_and_build(self) -> None:
             current = queue.popleft()
             ordered.append(current)
             for child in all_routes.get(current, []):
+                if child in self._outputs:
+                    continue
                 in_degree[child] -= 1
                 if in_degree[child] == 0:
                     queue.append(child)
@@ -155,44 +156,54 @@ def _validate_and_build(self) -> None:
         self._root_nodes = root_nodes
         self._ordered = ordered
 
-    @requires_validate_and_build
-    def run(
-        self,
-        ctx: EvalContext,
-    ) -> Output:
-        """
-        Execute the DAG on a single prompt/response.
-        """
+    def _run_traced(
+        self, ctx: EvalContext
+    ) -> tuple[Output, dict[str, Any], set[tuple[str, str]]]:
+        """Execute the DAG and return (final output, node outputs, traversed edges)."""
         active_nodes = self._root_nodes
-        outputs: dict[str, Any] = {}
+        node_outputs: dict[str, Any] = {}
+        traversed_edges: set[tuple[str, str]] = set()
         while active_nodes:
             next_active = []
             for node_name in active_nodes:
-                # set parent outputs in context for this node
                 ctx.set_parent_outputs(
                     {
-                        pred: outputs[pred]
+                        pred: node_outputs[pred]
                         for pred in self._predecessors[node_name]
-                        if pred in outputs
+                        if pred in node_outputs
                     }
                 )
-                # run the node
                 node = self._nodes[node_name]
                 output = node.run(ctx)
                 if isinstance(output, Output):
-                    return output
-                outputs[node_name] = output
-                # see which nodes to activate next based on output and routing
-                next_active.extend(node.next_nodes(output))
+                    traversed_edges.add((node_name, output.name))
+                    return output, node_outputs, traversed_edges
+                node_outputs[node_name] = output
+                for target in node.next_nodes(output):
+                    t = target if isinstance(target, str) else target.name
+                    traversed_edges.add((node_name, t))
+                    if isinstance(target, Output):
+                        return target, node_outputs, traversed_edges
+                    next_active.append(t)
             active_nodes = next_active
         raise ValueError("DAG execution completed without reaching an Output node.")
 
+    @requires_validate_and_build
+    def run(
+        self,
+        ctx: EvalContext,
+    ) -> Output:
+        """Execute the DAG on a single prompt/response."""
+        output, _, _ = self._run_traced(ctx)
+        return output
+
     @requires_validate_and_build
     def run_dataframe(
         self,
         df: pd.DataFrame,
         prompt_col: str = "prompt",
         response_col: str = "response",
+        n_jobs: int = 1,
     ) -> pd.DataFrame:
         """Run the DAG over every row of a DataFrame."""
 
@@ -203,7 +214,14 @@ def _run_row(row: Any) -> Output:
             )
             return self.run(ctx)
 
-        records = [_run_row(row) for _, row in df.iterrows()]
+        rows = [row for _, row in df.iterrows()]
+
+        if n_jobs == 1:
+            records = [_run_row(row) for row in rows]
+        else:
+            max_workers = os.cpu_count() if n_jobs == -1 else n_jobs
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                records = list(executor.map(_run_row, rows))
 
         result_df = pd.DataFrame(
             {self.DATAFRAME_OUTPUT_COL: [r.name for r in records]}, index=df.index
@@ -241,20 +259,169 @@ def _dfs(node_name: str, accumulated: float, path: list[str]) -> None:
         return path_costs
 
     @requires_validate_and_build
-    def visualize(self) -> None:
-        """Render the DAG structure with ascii."""
-        print(f"EvaluatorDAG: {self.name!r}")
-        print("=" * (len(self.name) + 18))
-        for node_name in self._ordered:
-            node = self._nodes[node_name]
-            node_type = type(node).__name__
-            if isinstance(node, Output):
-                route_str = f" → verdict='{node.name}'"
-            elif isinstance(node, Gate):
-                route_str = f" → True:{node.routes_true} False:{node.routes_false}"
-            else:
-                route_str = f" → {node.routes}"
-            print(f" [{node_type:10s}] {node_name}{route_str}")
+    def visualize(
+        self,
+        node_outputs: Optional[dict[str, Any]] = None,
+        traversed_edges: Optional[set[tuple[str, str]]] = None,
+        final_output: Optional[Output] = None,
+    ):
+        """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline.
+
+        When node_outputs/traversed_edges/final_output are provided (via visualize_run),
+        the hot path is highlighted and each node shows its output value.
+        """
+        import graphviz
+        from IPython.display import Image
+
+        traced = node_outputs is not None
+
+        def _format_output(value: Any) -> str:
+            if isinstance(value, float):
+                return f"{value:.3g}"
+            s = str(value)
+            return s if len(s) <= 30 else s[:27] + "..."
+
+        _NODE_STYLES: dict[type, dict] = {
+            Gate: {"shape": "diamond", "style": "filled", "fillcolor": "#d0e8f5"},
+            Arbiter: {"shape": "box", "style": "filled", "fillcolor": "#c8e6c9"},
+            Output: {"shape": "ellipse", "style": "filled", "fillcolor": "#fff9c4"},
+        }
+        _DEFAULT_STYLE = {"shape": "box", "style": "filled", "fillcolor": "#ffe0b2"}
+        _DIM = {
+            "style": "filled",
+            "fillcolor": "#f0f0f0",
+            "color": "#bbbbbb",
+            "fontcolor": "#aaaaaa",
+        }
+
+        dot = graphviz.Digraph(name=self.name)
+        dot.attr(
+            label=self.name,
+            labelloc="t",
+            fontsize="13",
+            fontname="Helvetica",
+            rankdir="TB",
+            ranksep="0.5",
+            nodesep="0.4",
+        )
+        dot.attr("node", fontname="Helvetica", fontsize="11")
+        dot.attr("edge", fontname="Helvetica", fontsize="10")
+
+        # implicit input node pinned to the top
+        top = graphviz.Digraph()
+        top.attr(rank="min")
+        top.node(
+            "__input__",
+            "prompt\nresponse",
+            shape="box",
+            style="dashed",
+            fillcolor="white",
+            color="#888888",
+            fontcolor="#555555",
+        )
+        dot.subgraph(top)
+
+        # output terminal nodes pinned to the bottom
+        bottom = graphviz.Digraph()
+        bottom.attr(rank="max")
+        for output_name, output_node in self._outputs.items():
+            attrs = dict(_NODE_STYLES[Output])
+            if traced:
+                if output_node is final_output:
+                    attrs["penwidth"] = "2.5"
+                else:
+                    attrs = dict(_DIM, shape="ellipse")
+            bottom.node(output_name, **attrs)
+        dot.subgraph(bottom)
+
+        # processing nodes
+        for node_name, node in self._nodes.items():
+            base_style = next(
+                (s for t, s in _NODE_STYLES.items() if isinstance(node, t)),
+                _DEFAULT_STYLE,
+            )
+            if traced and node_name not in node_outputs:
+                attrs = dict(_DIM, shape=base_style.get("shape", "box"))
+                label = node_name
+            else:
+                attrs = dict(base_style)
+                if traced:
+                    raw = node_outputs[node_name]  # type: ignore[index]
+                    label = f"{node_name}\n{_format_output(raw)}"
+                else:
+                    label = node_name
+            dot.node(node_name, label, **attrs)
+
+        # dashed edges from implicit input to root nodes
+        for root in self._root_nodes:
+            dot.edge(
+                "__input__", root, style="dashed", color="#888888", arrowhead="open"
+            )
+
+        # edges between processing nodes
+        for node_name, node in self._nodes.items():
+            if isinstance(node, Gate):
+                for target in node.routes_true:
+                    t = target if isinstance(target, str) else target.name
+                    hot = not traced or (node_name, t) in traversed_edges  # type: ignore[operator]
+                    dot.edge(
+                        node_name,
+                        t,
+                        label=" True",
+                        color="#2e7d32" if hot else "#cccccc",
+                        fontcolor="#2e7d32" if hot else "#cccccc",
+                        penwidth="2" if hot and traced else "1",
+                    )
+                for target in node.routes_false:
+                    t = target if isinstance(target, str) else target.name
+                    hot = not traced or (node_name, t) in traversed_edges  # type: ignore[operator]
+                    dot.edge(
+                        node_name,
+                        t,
+                        label=" False",
+                        color="#c62828" if hot else "#cccccc",
+                        fontcolor="#c62828" if hot else "#cccccc",
+                        penwidth="2" if hot and traced else "1",
+                    )
+            elif isinstance(node, Arbiter):
+                for output in node.outputs():
+                    hot = not traced or (node_name, output.name) in traversed_edges  # type: ignore[operator]
+                    dot.edge(
+                        node_name,
+                        output.name,
+                        color="#555555" if hot else "#cccccc",
+                        penwidth="2" if hot and traced else "1",
+                    )
+            else:
+                for target in node.routes:
+                    t = target if isinstance(target, str) else target.name
+                    hot = not traced or (node_name, t) in traversed_edges  # type: ignore[operator]
+                    dot.edge(
+                        node_name,
+                        t,
+                        color="#555555" if hot else "#cccccc",
+                        penwidth="2" if hot and traced else "1",
+                    )
+
+        try:
+            return Image(dot.pipe(format="png"))
+        except graphviz.ExecutableNotFound:
+            raise RuntimeError(
+                "Graphviz system binaries not found. Install them with:\n"
+                "  macOS: brew install graphviz\n"
+                "  Ubuntu: apt-get install graphviz\n"
+                "  conda: conda install graphviz"
+            ) from None
+
+    @requires_validate_and_build
+    def visualize_run(self, ctx: EvalContext):
+        """Run the DAG on ctx and return a PNG with the executed path highlighted."""
+        final_output, node_outputs, traversed_edges = self._run_traced(ctx)
+        return self.visualize(
+            node_outputs=node_outputs,
+            traversed_edges=traversed_edges,
+            final_output=final_output,
+        )
 
 
 class DAGAnnotator(Annotator):
diff --git a/uv.lock b/uv.lock
index 7b2de6e..8864835 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1975,6 +1975,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/74/16/a4cf06adbc711bd364a73ce043b0b08d8fa5aae3df11b6ee4248bcdad2e0/graphql_relay-3.2.0-py3-none-any.whl", hash = "sha256:c9b22bd28b170ba1fe674c74384a8ff30a76c8e26f88ac3aa1584dd3179953e5", size = 16940, upload-time = "2022-04-16T11:03:43.895Z" },
 ]
 
+[[package]]
+name = "graphviz"
+version = "0.21"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/f8/b3/3ac91e9be6b761a4b30d66ff165e54439dcd48b83f4e20d644867215f6ca/graphviz-0.21.tar.gz", hash = "sha256:20743e7183be82aaaa8ad6c93f8893c923bd6658a04c32ee115edb3c8a835f78", size = 200434, upload-time = "2025-06-15T09:35:05.824Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/91/4c/e0ce1ef95d4000ebc1c11801f9b944fa5910ecc15b5e351865763d8657f8/graphviz-0.21-py3-none-any.whl", hash = "sha256:54f33de9f4f911d7e84e4191749cac8cc5653f815b06738c54db9a15ab8b1e42", size = 47300, upload-time = "2025-06-15T09:35:04.433Z" },
+]
+
 [[package]]
 name = "greenlet"
 version = "3.3.1"
@@ -3266,6 +3275,7 @@ source = { editable = "." }
 dependencies = [
     { name = "click" },
     { name = "dvc", extra = ["gs"] },
+    { name = "graphviz" },
     { name = "jsonlines" },
     { name = "jupyter" },
     { name = "jupyterlab-git" },
@@ -3298,6 +3308,7 @@ test = [
 requires-dist = [
     { name = "click", specifier = ">=8,<9" },
     { name = "dvc", extras = ["gs"], specifier = ">=3.60,<4" },
+    { name = "graphviz", specifier = ">=0.20,<1" },
     { name = "jsonlines", specifier = ">=4,<5" },
     { name = "jupyter", specifier = ">=1,<2" },
     { name = "jupyterlab-git" },

From 0d110aefde3ed8d9b3337f8c1510a9c7ad9e7c89 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 10:04:41 -0400
Subject: [PATCH 05/19] Clean up docs and error message.
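
The usage example previously passed routes_true=/routes_false= to MyArbiter;
Arbiters are terminal nodes that declare their possible verdicts via
outputs(), so construction takes no routing kwargs. A sketch of the intended
shape (MeanScoreArbiter is hypothetical, for illustration only):

    from modelplane.evaluator.context import EvalContext
    from modelplane.evaluator.nodes import Arbiter
    from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output

    class MeanScoreArbiter(Arbiter):
        """Averages parent scores and compares against a threshold."""

        def __init__(self, name: str, threshold: float = 0.5) -> None:
            super().__init__(name)  # no routes_*: _validate_terminal enforces this
            self.threshold = threshold

        def run(self, ctx: EvalContext) -> Output:
            scores = ctx.parent_outputs()
            mean = sum(scores) / len(scores)
            return VIOLATING if mean >= self.threshold else NONVIOLATING

        def outputs(self) -> list[Output]:
            return [VIOLATING, NONVIOLATING]

    arb = MeanScoreArbiter("arbiter", threshold=0.5)  # fine: terminal, no routes
    # Passing routes= here would fail _validate_terminal at construction time.
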
---
 src/modelplane/evaluator/dag.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index a7475e1..45db53e 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -34,10 +34,10 @@ class EvaluatorDAG:
 
     refusal_gate = MyRefusalGate("RefusalGate", routes_true=[NONVIOLATING], routes_false=["NonRefusal"])
     eval_non_refusal = MyNonRefusalEvaluator("NonRefusal", routes=["Arbiter"])
-    arbiter = MyArbiter("Arbiter", routes_true=[VIOLATING], routes_false=[NONVIOLATING])
+    arbiter = MyArbiter("Arbiter")
 
     dag = (
-        EvaluatorDAG("refusal_evaluator", outputs=[NONVIOLATING, VIOLATING])
+        EvaluatorDAG("refusal_gated_safety_evaluator", outputs=[NONVIOLATING, VIOLATING])
         .add_node(refusal_gate)
         .add_node(eval_non_refusal)
         .add_node(arbiter)
@@ -67,7 +67,7 @@ def add_node(
 
         if node.name in self._all_names():
             raise ValueError(
-                f"A different node named {node.name!r} is already registered."
+                f"A different node named {node.name} is already registered."
             )
         self._nodes[node.name] = node
         self._validated = False

From 17c1fe3aad8af6c2191af6ce42b76238583f6d40 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 10:04:56 -0400
Subject: [PATCH 06/19] Make it harder to modify routes after instantiation.

---
 src/modelplane/evaluator/nodes.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
index 0e72d33..602d688 100644
--- a/src/modelplane/evaluator/nodes.py
+++ b/src/modelplane/evaluator/nodes.py
@@ -27,11 +27,23 @@ def __init__(
         routes: Optional[list[str | Output]] = None,
     ) -> None:
         self.name = name
-        self.routes_true = routes_true or []
-        self.routes_false = routes_false or []
-        self.routes = routes or []
+        self._routes_true: tuple[str | Output, ...] = tuple(routes_true or [])
+        self._routes_false: tuple[str | Output, ...] = tuple(routes_false or [])
+        self._routes: tuple[str | Output, ...] = tuple(routes or [])
         self.validate()
 
+    @property
+    def routes_true(self) -> tuple[str | Output, ...]:
+        return self._routes_true
+
+    @property
+    def routes_false(self) -> tuple[str | Output, ...]:
+        return self._routes_false
+
+    @property
+    def routes(self) -> tuple[str | Output, ...]:
+        return self._routes
+
     @abstractmethod
     def run(self, ctx: EvalContext) -> Any:
         """Execute the node and return its output."""
@@ -53,8 +65,8 @@ def all_routes(self) -> list[str]:
             *[r if isinstance(r, str) else r.name for r in self.routes],
         ]
 
-    def next_nodes(self, output: Any) -> list[str | Output]:
-        """Given the node's output value, return the list of next node names to activate."""
+    def next_nodes(self, output: Any) -> tuple[str | Output, ...]:
+        """Given the node's output value, return the tuple of next node names to activate."""
         if isinstance(self, Gate):
             return self.routes_true if output else self.routes_false
         else:
             return self.routes

From c384469177f1642ee6a6a4ef57a247eb68789671 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 10:30:28 -0400
Subject: [PATCH 07/19] Generalize outputs.
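
Output no longer hard-codes a safety verdict; subclasses define name (used
for routing and reporting), and Safety carries is_safe as data. Non-safety
DAGs can now define their own verdict types. A sketch of a hypothetical
multi-class output (not part of this change):

    from modelplane.evaluator.outputs import Output

    class HazardCategory(Output):
        """Verdict carrying a hazard label instead of a binary safety flag."""

        def __init__(self, category: str) -> None:
            self.category = category

        @property
        def name(self) -> str:
            return self.category

    HATE = HazardCategory("hate")
    FRAUD = HazardCategory("fraud")
    # dag = EvaluatorDAG("hazards", outputs=[HATE, FRAUD])
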
---
 src/modelplane/evaluator/dag.py     | 15 ++++++++++++++-
 src/modelplane/evaluator/outputs.py | 27 +++++++++++----------------
 tests/unit/evaluator/conftest.py    | 28 ++++++++++++++--------------
 tests/unit/evaluator/mocks.py       | 18 +++++++++---------
 tests/unit/evaluator/test_dag.py    |  4 ++--
 tests/unit/evaluator/test_nodes.py  | 16 +++++++---------
 6 files changed, 57 insertions(+), 51 deletions(-)

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index 45db53e..0625604 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -16,6 +16,7 @@
 
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, EvaluatorDAGNode, Gate, Output
+from modelplane.evaluator.outputs import Safety
 
 
 def requires_validate_and_build(method):
@@ -59,6 +59,11 @@ def __init__(self, name: str, outputs: list[Output]) -> None:
         self._predecessors: dict[str, list[str]] = collections.defaultdict(list)
         self._outputs = {output.name: output for output in outputs}
 
+    @property
+    def outputs(self) -> list[Output]:
+        """Return the list of Output nodes declared in the DAG constructor."""
+        return list(self._outputs.values())
+
     def add_node(
         self,
         node: EvaluatorDAGNode,
@@ -447,6 +452,14 @@ def translate_prompt(
     def annotate(self, annotation_request: EvalContext) -> Output:
         return self.dag.run(annotation_request)
 
+
+class SafetyDAGAnnotator(DAGAnnotator):
+
+    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
+        super().__init__(uid, dag)
+        if not all(isinstance(o, Safety) for o in dag.outputs):
+            raise ValueError("All outputs of the DAG must be of type Safety.")
+
     def translate_response(
         self,
         request: EvalContext,
@@ -454,4 +467,4 @@ def translate_response(
     ) -> SafetyAnnotation:
         """Map the DAG's Output verdict to a SafetyAnnotation (is_safe bool)."""
         # TODO: unclear whether SafetyAnnotation is the right standardized output
-        return SafetyAnnotation(is_safe=response.is_safe())
+        return SafetyAnnotation(is_safe=response.is_safe)
diff --git a/src/modelplane/evaluator/outputs.py b/src/modelplane/evaluator/outputs.py
index ad496d6..203821e 100644
--- a/src/modelplane/evaluator/outputs.py
+++ b/src/modelplane/evaluator/outputs.py
@@ -2,29 +2,24 @@
 
 
 class Output:
-    @abstractmethod
-    def is_safe(self) -> bool:
-        pass
-
     @property
+    @abstractmethod
     def name(self) -> str:
-        return self.__class__.__name__
+        """Return a string name for this output, used for routing and debugging."""
 
     def __repr__(self) -> str:
-        return self.__class__.__name__
+        return f"{self.name} ({self.__class__.__name__})"
 
 
-class Violating(Output):
+class Safety(Output):
 
-    def is_safe(self) -> bool:
-        return False
+    def __init__(self, is_safe: bool) -> None:
+        self.is_safe = is_safe
 
-
-class NonViolating(Output):
-
-    def is_safe(self) -> bool:
-        return True
+    @property
+    def name(self) -> str:
+        return "SAFE" if self.is_safe else "UNSAFE"
 
 
-VIOLATING = Violating()
-NONVIOLATING = NonViolating()
+SAFE = Safety(is_safe=True)
+UNSAFE = Safety(is_safe=False)
diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
index 8fa44c8..d63bcd2 100644
--- a/tests/unit/evaluator/conftest.py
+++ b/tests/unit/evaluator/conftest.py
@@ -4,26 +4,26 @@
 
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.dag import EvaluatorDAG
-from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output
+from modelplane.evaluator.outputs import SAFE, UNSAFE, Output
+
 from .mocks import (
     AlwaysFalse,
-    AlwaysNonViolating,
+    AlwaysSafe,
     AlwaysTrue,
-    AlwaysViolating,
+    AlwaysUnsafe,
     FixedScorer,
     LLMEnricher,
-    LLMEnricher,
-    LowerCaseScorer,
     LowerCaser,
+    LowerCaseScorer,
     PromptLengthGate,
     ThresholdArbiter,
-    UpperCaseScorer,
     UpperCaser,
+    UpperCaseScorer,
 )
 
-TRUE_BRANCH: list[str | Output] = ["true_branch"]
-FALSE_BRANCH: list[str | Output] = ["false_branch"]
-DEFAULT_BRANCH: list[str | Output] = ["next_node"]
+TRUE_BRANCH: tuple[str | Output, ...] = ("true_branch",)
+FALSE_BRANCH: tuple[str | Output, ...] = ("false_branch",)
+DEFAULT_BRANCH: tuple[str | Output, ...] = ("next_node",)
 SCORE1 = 1.0
 SCORE2 = 2.0
 
@@ -68,13 +68,13 @@ def sample_ctx() -> EvalContext:
 
 
 @pytest.fixture
-def always_violating() -> AlwaysViolating:
-    return AlwaysViolating(name="always_violating")
+def always_unsafe() -> AlwaysUnsafe:
+    return AlwaysUnsafe(name="always_unsafe")
 
 
 @pytest.fixture
-def always_non_violating() -> AlwaysNonViolating:
-    return AlwaysNonViolating(name="always_non_violating")
+def always_safe() -> AlwaysSafe:
+    return AlwaysSafe(name="always_safe")
 
 
 @pytest.fixture
@@ -85,7 +85,7 @@ def threshold_arbiter() -> ThresholdArbiter:
 @pytest.fixture
 def simple_dag():
     return (
-        EvaluatorDAG("simple", outputs=[NONVIOLATING, VIOLATING])
+        EvaluatorDAG("simple", outputs=[SAFE, UNSAFE])
         .add_node(
             PromptLengthGate(
                 name="prompt_parity",
diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py
index f405486..4ad6ba7 100644
--- a/tests/unit/evaluator/mocks.py
+++ b/tests/unit/evaluator/mocks.py
@@ -1,6 +1,6 @@
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, Enricher, Gate, Scorer
-from modelplane.evaluator.outputs import NONVIOLATING, VIOLATING, Output
+from modelplane.evaluator.outputs import SAFE, UNSAFE, Output
 
 
 class PassthroughGate(Gate):
@@ -77,20 +77,20 @@ def run(self, ctx: EvalContext) -> float:
         return num_upper / len(ctx.response)
 
 
-class AlwaysViolating(Arbiter):
+class AlwaysUnsafe(Arbiter):
     def run(self, ctx: EvalContext) -> Output:
-        return VIOLATING
+        return UNSAFE
 
     def outputs(self) -> list[Output]:
-        return [VIOLATING]
+        return [UNSAFE]
 
 
-class AlwaysNonViolating(Arbiter):
+class AlwaysSafe(Arbiter):
     def run(self, ctx: EvalContext) -> Output:
-        return NONVIOLATING
+        return SAFE
 
     def outputs(self) -> list[Output]:
-        return [NONVIOLATING]
+        return [SAFE]
 
 
 class ThresholdArbiter(Arbiter):
@@ -101,7 +101,7 @@ def __init__(self, name: str, threshold: float, **kwargs):
     def run(self, ctx: EvalContext) -> Output:
         scores = ctx.parent_outputs()
         score = sum(scores) / len(scores)
-        return VIOLATING if score >= self.threshold else NONVIOLATING
+        return UNSAFE if score >= self.threshold else SAFE
 
     def outputs(self) -> list[Output]:
-        return [VIOLATING, NONVIOLATING]
+        return [UNSAFE, SAFE]
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index 3532045..29544ef 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -5,7 +5,7 @@
 
 def test_dag_run(simple_dag, sample_ctx):
     result = simple_dag.run(sample_ctx)
-    assert result.is_safe()
+    assert result.name == "SAFE"
 
 
 def test_dag_run_with_dataframe(simple_dag):
@@ -24,5 +24,5 @@ def test_dag_run_with_dataframe(simple_dag):
     assert "prompt" in result_df.columns
     assert "response" in result_df.columns
     verdicts = result_df[simple_dag.DATAFRAME_OUTPUT_COL].tolist()
-    expected_verdicts = ["NonViolating", "Violating", "NonViolating", "Violating"]
+    expected_verdicts = ["SAFE", "UNSAFE", "SAFE", "UNSAFE"]
     assert verdicts == expected_verdicts
diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py
index e046d19..1aaae07 100644
--- a/tests/unit/evaluator/test_nodes.py
+++ b/tests/unit/evaluator/test_nodes.py
@@ -27,24 +27,22 @@ def test_fixed_scorer(sample_ctx, score_1):
     assert score_1.next_nodes(output) == DEFAULT_BRANCH
 
 
-def test_consistent_arbiters(
-    sample_ctx, score_1, score_2, always_violating, always_non_violating
-):
+def test_consistent_arbiters(sample_ctx, score_1, score_2, always_unsafe, always_safe):
     parent_outputs = {score_1.name: SCORE1, score_2.name: SCORE2}
     sample_ctx.set_parent_outputs(parent_outputs)
-    output = always_violating.run(sample_ctx)
-    assert not output.is_safe()
-    output = always_non_violating.run(sample_ctx)
-    assert output.is_safe()
+    output = always_unsafe.run(sample_ctx)
+    assert output.name == "UNSAFE"
+    output = always_safe.run(sample_ctx)
+    assert output.name == "SAFE"
 
 
 def test_threshold_arbiter_true(sample_ctx, threshold_arbiter):
     sample_ctx.set_parent_outputs({"parent0": SCORE2, "parent1": SCORE2})
     output = threshold_arbiter.run(sample_ctx)
-    assert not output.is_safe()
+    assert output.name == "UNSAFE"
 
 
 def test_threshold_arbiter_false(sample_ctx, threshold_arbiter):
     sample_ctx.set_parent_outputs({"parent0": SCORE1, "parent1": SCORE1})
     output = threshold_arbiter.run(sample_ctx)
-    assert output.is_safe()
+    assert output.name == "SAFE"

From 602ae21b2490fe51ca7daffb8081f55b54af23e2 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 14:23:43 -0400
Subject: [PATCH 08/19] Improved errors and tests and various bug fixes.

---
 src/modelplane/evaluator/dag.py   | 153 ++++++++++++++++++------------
 src/modelplane/evaluator/nodes.py |   8 +-
 tests/unit/evaluator/conftest.py  |  63 +++++++++++-
 tests/unit/evaluator/mocks.py     |  71 +++++++++++++-
 4 files changed, 224 insertions(+), 71 deletions(-)

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index 0625604..da3b62a 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -3,6 +3,7 @@
 import collections
 import functools
 import os
+from itertools import product
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from typing import Any, Optional
@@ -130,27 +131,18 @@ def _validate_and_build(self) -> None:
             queue.append(child)
 
         if len(ordered) != len(self._nodes):
-            # missing nodes
-            missing = set(self._nodes) - set(ordered)
-            raise ValueError(f"Graph contains a cycle. Missing nodes: {missing}")
+            nodes_in_cycle = set(self._nodes) - set(ordered)
+            raise ValueError(f"DAG contains a cycle. Nodes in cycle: {nodes_in_cycle}")
 
-        # check all terminal nodes are Output nodes
+        # check all terminal Arbiter nodes have correct outputs
         terminal_nodes = [n for n in self._nodes if not all_routes.get(n)]
         for terminal in terminal_nodes:
             entity = all_named_entities[terminal]
-            if isinstance(entity, Output) and terminal not in self._outputs:
-                raise ValueError(
-                    f"Terminal Output node {terminal!r} is not declared as an output in the DAG constructor."
-                )
-            elif isinstance(entity, Arbiter):
+            if isinstance(entity, Arbiter):
                 if any(o.name not in self._outputs for o in entity.outputs()):
                     raise ValueError(
                         f"Terminal Arbiter node {terminal!r} has output(s) that are not declared as outputs in the DAG constructor."
                     )
-            else:
-                raise ValueError(
-                    f"Terminal node {terminal!r} is not an Output or Arbiter node."
- ) # get predecessors for name, node in self._nodes.items(): @@ -165,32 +157,31 @@ def _run_traced( self, ctx: EvalContext ) -> tuple[Output, dict[str, Any], set[tuple[str, str]]]: """Execute the DAG and return (final output, node outputs, traversed edges).""" - active_nodes = self._root_nodes node_outputs: dict[str, Any] = {} traversed_edges: set[tuple[str, str]] = set() - while active_nodes: - next_active = [] - for node_name in active_nodes: - ctx.set_parent_outputs( - { - pred: node_outputs[pred] - for pred in self._predecessors[node_name] - if pred in node_outputs - } - ) - node = self._nodes[node_name] - output = node.run(ctx) - if isinstance(output, Output): - traversed_edges.add((node_name, output.name)) - return output, node_outputs, traversed_edges - node_outputs[node_name] = output - for target in node.next_nodes(output): - t = target if isinstance(target, str) else target.name - traversed_edges.add((node_name, t)) - if isinstance(target, Output): - return target, node_outputs, traversed_edges - next_active.append(t) - active_nodes = next_active + reachable: set[str] = set(self._root_nodes) + for node_name in self._ordered: + if node_name not in reachable: + continue + ctx.set_parent_outputs( + { + pred: node_outputs[pred] + for pred in self._predecessors[node_name] + if pred in node_outputs + } + ) + node = self._nodes[node_name] + output = node.run(ctx) + node_outputs[node_name] = output + if isinstance(output, Output): + traversed_edges.add((node_name, output.name)) + return output, node_outputs, traversed_edges + for target in node.next_nodes(output): + t = target if isinstance(target, str) else target.name + traversed_edges.add((node_name, t)) + if isinstance(target, Output): + return target, node_outputs, traversed_edges + reachable.add(t) raise ValueError("DAG execution completed without reaching an Output node.") @requires_validate_and_build @@ -234,32 +225,62 @@ def _run_row(row: Any) -> Output: return pd.concat([df, result_df], axis=1) @requires_validate_and_build - def total_cost( - self, - prompt: Optional[str], - response: Optional[str], - ) -> dict[str, float]: - """Run the DAG on all terminal paths and report total costs per path. 
- If no prompt/response are provided, uses empty strings.""" - - ctx = EvalContext( - prompt=prompt or "", - response=response or "", - ) + def total_cost(self, ctx: Optional[EvalContext] = None) -> float: + """Run the DAG on ctx and return the total cost of the executed path.""" + if ctx is None: + ctx = EvalContext(prompt="", response="") + _, node_outputs, _ = self._run_traced(ctx) + total = 0.0 + for node_name in node_outputs: + node = self._nodes[node_name] + total += node.cost(ctx) + return total + @requires_validate_and_build + def total_costs(self) -> dict[str, float]: + """Run the DAG on all terminal paths and report total costs per path.""" + ctx = EvalContext(prompt="", response="") + gates = [name for name, node in self._nodes.items() if isinstance(node, Gate)] path_costs: dict[str, float] = {} - def _dfs(node_name: str, accumulated: float, path: list[str]) -> None: - node = self._nodes[node_name] - total = accumulated + node.cost(ctx) - if isinstance(node, Output): - path_costs[" -> ".join(path + [node_name])] = total - return - for target in node.all_routes(): - _dfs(target, total, path + [node_name]) + for combo in product([True, False], repeat=len(gates)): + gate_outcomes = dict(zip(gates, combo)) + reachable: set[str] = set(self._root_nodes) + path: list[str] = [] + total = 0.0 + terminal_outputs: list[str] = [] - for root in self._root_nodes: - _dfs(root, 0.0, []) + for node_name in self._ordered: + if node_name not in reachable: + continue + node = self._nodes[node_name] + total += node.cost(ctx) + path.append(node_name) + if isinstance(node, Gate): + targets = ( + node.routes_true + if gate_outcomes[node_name] + else node.routes_false + ) + elif isinstance(node, Arbiter): + terminal_outputs = [o.name for o in node.outputs()] + targets = [] + else: + targets = node.routes + for target in targets: + if isinstance(target, Output): + terminal_outputs = [target.name] + else: + reachable.add( + target if isinstance(target, str) else target.name + ) + + base_path = " -> ".join(path) + if terminal_outputs: + for output_name in terminal_outputs: + path_costs[f"{base_path} -> {output_name}"] = total + else: + path_costs[base_path] = total return path_costs @@ -345,14 +366,22 @@ def _format_output(value: Any) -> str: (s for t, s in _NODE_STYLES.items() if isinstance(node, t)), _DEFAULT_STYLE, ) - if traced and node_name not in node_outputs: + node_was_active = (node_outputs is not None and node_name in node_outputs) or ( + traversed_edges is not None + and any(src == node_name for src, _ in traversed_edges) + ) + if traced and not node_was_active: attrs = dict(_DIM, shape=base_style.get("shape", "box")) label = node_name else: attrs = dict(base_style) if traced: - raw = node_outputs[node_name] # type: ignore[index] - label = f"{node_name}\n{_format_output(raw)}" + if node_name in node_outputs: + raw = node_outputs[node_name] # type: ignore[index] + label = f"{node_name}\n{_format_output(raw)}" + else: + label = node_name + attrs["penwidth"] = "2.5" else: label = node_name dot.node(node_name, label, **attrs) diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py index 602d688..9d8e4f3 100644 --- a/src/modelplane/evaluator/nodes.py +++ b/src/modelplane/evaluator/nodes.py @@ -12,7 +12,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, Sequence from modelplane.evaluator.context import EvalContext from modelplane.evaluator.outputs import Output @@ -22,9 +22,9 @@ class EvaluatorDAGNode(ABC): def 
__init__( self, name: str, - routes_true: Optional[list[str | Output]] = None, - routes_false: Optional[list[str | Output]] = None, - routes: Optional[list[str | Output]] = None, + routes_true: Optional[Sequence[str | Output]] = None, + routes_false: Optional[Sequence[str | Output]] = None, + routes: Optional[Sequence[str | Output]] = None, ) -> None: self.name = name self._routes_true: tuple[str | Output] = tuple(routes_true or []) diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py index d63bcd2..d20ef93 100644 --- a/tests/unit/evaluator/conftest.py +++ b/tests/unit/evaluator/conftest.py @@ -11,12 +11,14 @@ AlwaysSafe, AlwaysTrue, AlwaysUnsafe, + BadArbiter, FixedScorer, LLMEnricher, LowerCaser, LowerCaseScorer, PromptLengthGate, ThresholdArbiter, + UnexpectedArbiter, UpperCaser, UpperCaseScorer, ) @@ -24,6 +26,7 @@ TRUE_BRANCH: tuple[str | Output] = ("true_branch",) FALSE_BRANCH: tuple[str | Output] = ("false_branch",) DEFAULT_BRANCH: tuple[str | Output] = ("next_node",) +BAD_BRANCH: tuple[str | Output] = ("undefined_node",) SCORE1 = 1.0 SCORE2 = 2.0 @@ -35,6 +38,13 @@ def always_true_gate() -> AlwaysTrue: ) +@pytest.fixture +def bad_gate() -> AlwaysTrue: + return AlwaysTrue( + name="bad_gate", routes_true=BAD_BRANCH, routes_false=FALSE_BRANCH + ) + + @pytest.fixture def always_false_gate() -> AlwaysFalse: return AlwaysFalse( @@ -64,7 +74,7 @@ def costly_enricher() -> LLMEnricher: @pytest.fixture def sample_ctx() -> EvalContext: - return EvalContext(prompt="Hello, world!", response="This is a response.") + return EvalContext(prompt="Hello, world", response="This is a response.") @pytest.fixture @@ -86,10 +96,17 @@ def threshold_arbiter() -> ThresholdArbiter: def simple_dag(): return ( EvaluatorDAG("simple", outputs=[SAFE, UNSAFE]) + .add_node( + AlwaysTrue( + name="always_true", + routes_true=["lower_caser", "prompt_parity"], + routes_false=[SAFE], + ) + ) .add_node( PromptLengthGate( name="prompt_parity", - routes_true=["lower_caser"], + routes_true=[UNSAFE], routes_false=["upper_caser"], ) ) @@ -103,3 +120,45 @@ def simple_dag(): .add_node(UpperCaseScorer(name="upper_scorer", routes=["threshold_arbiter"])) .add_node(ThresholdArbiter(name="threshold_arbiter", threshold=0.5)) ) + + +@pytest.fixture() +def bad_dag_with_cycle(): + return ( + EvaluatorDAG("cyclic", outputs=[SAFE, UNSAFE]) + .add_node( + AlwaysTrue( + name="node1", + routes_true=["node2"], + routes_false=["node3"], + ) + ) + .add_node( + AlwaysTrue( + name="node2", + routes_true=["node3"], + routes_false=["node1"], + ) + ) + .add_node( + AlwaysTrue( + name="node3", + routes_true=[SAFE], + routes_false=[UNSAFE], + ) + ) + ) + + +@pytest.fixture +def bad_dag_with_undefined_output(simple_dag): + bad_arbiter = UnexpectedArbiter(name="arbiter") + simple_dag.add_node(bad_arbiter) + return simple_dag + + +@pytest.fixture +def bad_dag_with_bad_arbiter(): + dag = EvaluatorDAG("test", outputs=[SAFE, UNSAFE]) + dag.add_node(BadArbiter(name="bad_arbiter")) + return dag diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py index 4ad6ba7..ac680e9 100644 --- a/tests/unit/evaluator/mocks.py +++ b/tests/unit/evaluator/mocks.py @@ -13,15 +13,24 @@ def run(self, ctx: EvalContext) -> bool: class AlwaysTrue(PassthroughGate): ROUTE_TO_TAKE = True + def cost(self, ctx: EvalContext) -> float: + return 0.1 + class AlwaysFalse(PassthroughGate): ROUTE_TO_TAKE = False + def cost(self, ctx: EvalContext) -> float: + return 0.2 + class PromptLengthGate(Gate): def run(self, ctx: EvalContext) -> bool: 
return len(ctx.prompt) % 2 == 0 + def cost(self, ctx: EvalContext) -> float: + return 0.3 + class LowerCaser(Enricher): """Enriches by returning the response lowercased.""" @@ -29,6 +38,9 @@ class LowerCaser(Enricher): def run(self, ctx: EvalContext) -> str: return ctx.response.lower() + def cost(self, ctx: EvalContext) -> float: + return 0.4 + class UpperCaser(Enricher): """Enriches by returning the response uppercased.""" @@ -36,15 +48,18 @@ class UpperCaser(Enricher): def run(self, ctx: EvalContext) -> str: return ctx.response.upper() + def cost(self, ctx: EvalContext) -> float: + return 0.5 + class LLMEnricher(Enricher): - def cost(self, ctx: EvalContext) -> float: - return len(ctx.prompt) + len(ctx.response) - def run(self, ctx: EvalContext) -> str: return ctx.response + def cost(self, ctx: EvalContext) -> float: + return 0.6 + class FixedScorer(Scorer): """Returns a fixed float score regardless of context.""" @@ -56,6 +71,9 @@ def __init__(self, name: str, value: float, **kwargs): def run(self, ctx: EvalContext) -> float: return self.value + def cost(self, ctx: EvalContext) -> float: + return 0.7 + class LowerCaseScorer(Scorer): """Scores based on the percentage of lowercase characters in the response.""" @@ -66,6 +84,9 @@ def run(self, ctx: EvalContext) -> float: num_lower = sum(1 for c in ctx.response if c.islower()) return num_lower / len(ctx.response) + def cost(self, ctx: EvalContext) -> float: + return 0.8 + class UpperCaseScorer(Scorer): """Scores based on the percentage of uppercase characters in the response.""" @@ -76,6 +97,9 @@ def run(self, ctx: EvalContext) -> float: num_upper = sum(1 for c in ctx.response if c.isupper()) return num_upper / len(ctx.response) + def cost(self, ctx: EvalContext) -> float: + return 0.9 + class AlwaysUnsafe(Arbiter): def run(self, ctx: EvalContext) -> Output: @@ -84,6 +108,9 @@ def run(self, ctx: EvalContext) -> Output: def outputs(self) -> list[Output]: return [UNSAFE] + def cost(self, ctx: EvalContext) -> float: + return 1.0 + class AlwaysSafe(Arbiter): def run(self, ctx: EvalContext) -> Output: @@ -92,6 +119,9 @@ def run(self, ctx: EvalContext) -> Output: def outputs(self) -> list[Output]: return [SAFE] + def cost(self, ctx: EvalContext) -> float: + return 1.1 + class ThresholdArbiter(Arbiter): def __init__(self, name: str, threshold: float, **kwargs): @@ -105,3 +135,38 @@ def run(self, ctx: EvalContext) -> Output: def outputs(self) -> list[Output]: return [UNSAFE, SAFE] + + def cost(self, ctx: EvalContext) -> float: + return 1.2 + + +class UnexpectedOutput(Output): + @property + def name(self) -> str: + return "UNEXPECTED_OUTPUT" + + +class UnexpectedArbiter(Arbiter): + """An arbiter that returns an output not declared in outputs().""" + + def run(self, ctx: EvalContext) -> Output: + return UnexpectedOutput() + + def outputs(self) -> list[Output]: + return [UnexpectedOutput()] + + def cost(self, ctx: EvalContext) -> float: + return 1.3 + + +class BadArbiter(Arbiter): + """An arbiter that violates the contract by returning a non-Output value.""" + + def run(self, ctx: EvalContext) -> str: + return "safe" + + def outputs(self) -> list[Output]: + return [SAFE] + + def cost(self, ctx: EvalContext) -> float: + return 1.4 From ac4a26d0dfade1760c0c9165244848c3486bbdd0 Mon Sep 17 00:00:00 2001 From: Vishal Doshi Date: Tue, 7 Apr 2026 16:01:48 -0400 Subject: [PATCH 09/19] More fixes. 
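
EvalContext gains an optional metadata mapping so callers can thread
auxiliary inputs through a run without widening the constructor for
every new field. A minimal sketch of the intended use (the "locale"
key is illustrative only, not part of the API):

    from modelplane.evaluator.context import EvalContext

    # metadata defaults to an empty dict when omitted
    ctx = EvalContext(
        prompt="Is this safe?",
        response="I can't help with that.",
        metadata={"locale": "en_US"},  # hypothetical key, for illustration
    )
    assert ctx.metadata["locale"] == "en_US"

Also turns SafetyDAGAnnotator into an actual class: it was previously
declared with `def`, so it never subclassed DAGAnnotator at all.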
---
 src/modelplane/evaluator/context.py |  7 +++++--
 src/modelplane/evaluator/dag.py     | 10 ++++++----
 src/modelplane/evaluator/nodes.py   |  4 ++--
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/modelplane/evaluator/context.py b/src/modelplane/evaluator/context.py
index 7df9faa..c40c148 100644
--- a/src/modelplane/evaluator/context.py
+++ b/src/modelplane/evaluator/context.py
@@ -1,12 +1,15 @@
-from typing import Any
+from typing import Any, Optional
 
 
 class EvalContext:
     """Context state passed around during DAG execution."""
 
-    def __init__(self, prompt: str, response: str) -> None:
+    def __init__(
+        self, prompt: str, response: str, metadata: Optional[dict[str, Any]] = None
+    ) -> None:
         self.prompt = prompt
         self.response = response
+        self.metadata = metadata or {}
         self._parent_outputs = {}
 
     def set_parent_outputs(self, outputs: dict[str, Any]) -> None:
diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index da3b62a..7d25d22 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -17,6 +17,7 @@
 
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, EvaluatorDAGNode, Gate, Output
+from modelplane.evaluator.outputs import Safety
 
 
 def requires_validate_and_build(method):
@@ -366,7 +367,9 @@ def _format_output(value: Any) -> str:
                 (s for t, s in _NODE_STYLES.items() if isinstance(node, t)),
                 _DEFAULT_STYLE,
             )
-            node_was_active = (node_outputs is not None and node_name in node_outputs) or (
+            node_was_active = (
+                node_outputs is not None and node_name in node_outputs
+            ) or (
                 traversed_edges is not None
                 and any(src == node_name for src, _ in traversed_edges)
             )
@@ -482,7 +485,7 @@ def annotate(self, annotation_request: EvalContext) -> Output:
         return self.dag.run(annotation_request)
 
 
-def SafetyDAGAnnotator(DAGAnnotator):
+class SafetyDAGAnnotator(DAGAnnotator):
 
     def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
         super().__init__(uid, dag)
@@ -492,8 +495,7 @@ def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
     def translate_response(
         self,
         request: EvalContext,
-        response: Output,
+        response: Safety,
     ) -> SafetyAnnotation:
         """Map DAGResult verdict to a SafetyAnnotation (is_safe bool)."""
-        # TODO: unclear whether SafetyAnnotation is the right standardized output
         return SafetyAnnotation(is_safe=response.is_safe)
diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
index 9d8e4f3..cc7963a 100644
--- a/src/modelplane/evaluator/nodes.py
+++ b/src/modelplane/evaluator/nodes.py
@@ -122,8 +122,8 @@ class Enricher(EvaluatorDAGNode):
     """Context transformation node."""
 
     @abstractmethod
-    def run(self, ctx: EvalContext) -> str:
-        """Return a new string representing the enriched context."""
+    def run(self, ctx: EvalContext) -> Any:
+        """Return data representing the enriched context."""
 
     def validate(self) -> None:
         super().validate()

From d2fa1f22b5c0b9d22faab2c4b04a65909e32c68d Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 16:34:11 -0400
Subject: [PATCH 10/19] Improved tests.
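
With total_cost() and total_costs() now exercised by tests, the
expected shape of the cost report is worth spelling out. A sketch
against the mock costs in tests/unit/evaluator/mocks.py (the exact
numbers depend on how the fixture wires the nodes; see
test_dag_cost_all_paths):

    costs = dag.total_costs()
    # e.g. {"always_true -> always_safe -> SAFE": 1.2, ...}
    cheapest_path = min(costs, key=costs.get)

Note that total_costs() enumerates every True/False combination of the
Gate nodes via itertools.product, so it is O(2^n) in the number of
Gates and only suitable for small DAGs.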
---
 src/modelplane/evaluator/dag.py    | 36 +++++-------
 src/modelplane/evaluator/nodes.py  |  9 ++-
 tests/unit/evaluator/conftest.py   |  3 +-
 tests/unit/evaluator/test_dag.py   | 88 +++++++++++++++++++++++++++++-
 tests/unit/evaluator/test_nodes.py | 59 ++++++++++++++++++++
 5 files changed, 170 insertions(+), 25 deletions(-)

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index 7d25d22..1e6819e 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -277,16 +276,12 @@ def total_costs(self) -> dict[str, float]:
                         )
 
             base_path = " -> ".join(path)
-            if terminal_outputs:
-                for output_name in terminal_outputs:
-                    path_costs[f"{base_path} -> {output_name}"] = total
-            else:
-                path_costs[base_path] = total
+            for output_name in terminal_outputs:
+                path_costs[f"{base_path} -> {output_name}"] = total
 
         return path_costs
 
-    @requires_validate_and_build
-    def visualize(
+    def _visualize(
         self,
         node_outputs: Optional[dict[str, Any]] = None,
         traversed_edges: Optional[set[tuple[str, str]]] = None,
@@ -302,12 +298,6 @@
 
         traced = node_outputs is not None
 
-        def _format_output(value: Any) -> str:
-            if isinstance(value, float):
-                return f"{value:.3g}"
-            s = str(value)
-            return s if len(s) <= 30 else s[:27] + "..."
-
         _NODE_STYLES: dict[type, dict] = {
             Gate: {"shape": "diamond", "style": "filled", "fillcolor": "#d0e8f5"},
             Arbiter: {"shape": "box", "style": "filled", "fillcolor": "#c8e6c9"},
@@ -379,11 +369,8 @@
             else:
                 attrs = dict(base_style)
                 if traced:
-                    if node_name in node_outputs:
-                        raw = node_outputs[node_name]  # type: ignore[index]
-                        label = f"{node_name}\n{_format_output(raw)}"
-                    else:
-                        label = node_name
+                    raw = node_outputs[node_name]  # type: ignore[index]
+                    label = f"{node_name}\n{node.format_output(raw)}"
                     attrs["penwidth"] = "2.5"
                 else:
                     label = node_name
@@ -442,19 +429,24 @@
 
         try:
             return Image(dot.pipe(format="png"))
-        except graphviz.ExecutableNotFound:
+        except graphviz.ExecutableNotFound as e:
             raise RuntimeError(
                 "Graphviz system binaries not found. Install them with:\n"
                 "  macOS:  brew install graphviz\n"
                 "  Ubuntu: apt-get install graphviz\n"
                 "  conda:  conda install graphviz"
-            ) from None
+            ) from e
+
+    @requires_validate_and_build
+    def visualize(self):
+        """Visualize the DAG structure without execution."""
+        return self._visualize()
 
     @requires_validate_and_build
     def visualize_run(self, ctx: EvalContext):
-        """Run the DAG on ctx and return a PNG with the executed path highlighted."""
+        """Run the DAG on ctx and return a visualization with the executed path highlighted."""
         final_output, node_outputs, traversed_edges = self._run_traced(ctx)
-        return self.visualize(
+        return self._visualize(
             node_outputs=node_outputs,
             traversed_edges=traversed_edges,
             final_output=final_output,
diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
index cc7963a..7e1aadc 100644
--- a/src/modelplane/evaluator/nodes.py
+++ b/src/modelplane/evaluator/nodes.py
@@ -47,7 +47,7 @@
     @abstractmethod
     def run(self, ctx: EvalContext) -> Any:
         """Execute the node and return its output."""
-        pass
+        raise NotImplementedError
 
     def cost(self, ctx: EvalContext) -> float:
         """Return the estimated cost of running this node. Default is 0.0;
@@ -57,6 +57,13 @@ def cost(self, ctx: EvalContext) -> float:
     def __repr__(self) -> str:
         return f"{self.name!r}: ({self.__class__.__name__})"
 
+    def format_output(self, output: Any) -> str:
+        """Convenience method to format the node's output for debugging/visualization."""
+        if isinstance(output, float):
+            return f"{output:.3g}"
+        s = str(output)
+        return s if len(s) <= 30 else s[:27] + "..."
+
     def all_routes(self) -> list[str]:
         """Return a list of all route targets from this node."""
         return [
diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
index d20ef93..2043b80 100644
--- a/tests/unit/evaluator/conftest.py
+++ b/tests/unit/evaluator/conftest.py
@@ -100,9 +100,10 @@ def simple_dag():
             AlwaysTrue(
                 name="always_true",
                 routes_true=["lower_caser", "prompt_parity"],
-                routes_false=[SAFE],
+                routes_false=["always_safe"],
             )
         )
+        .add_node(AlwaysSafe(name="always_safe"))
         .add_node(
             PromptLengthGate(
                 name="prompt_parity",
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index 29544ef..059b3fe 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -1,11 +1,54 @@
 """Unit tests for EvaluatorDAG construction, validation, execution, and visualization."""
 
 import pandas as pd
+import pytest
+
+from modelplane.evaluator.outputs import SAFE, UNSAFE
+
+
+def test_dag_outputs(simple_dag):
+    assert simple_dag.outputs == [SAFE, UNSAFE]
+
+
+def test_add_node_with_same_name_as_existing_node(simple_dag, always_true_gate):
+    always_true_gate.name = next(iter(simple_dag._nodes))
+    with pytest.raises(ValueError, match="is already registered"):
+        simple_dag.add_node(always_true_gate)  # same name as existing node
+
+
+def test_add_node_with_same_name_as_output(simple_dag, always_true_gate):
+    always_true_gate.name = SAFE.name
+    with pytest.raises(ValueError, match="is already registered"):
+        simple_dag.add_node(always_true_gate)  # same name as existing output
+
+
+def test_add_node_with_undefined_target_node(simple_dag, bad_gate):
+    simple_dag.add_node(bad_gate)
+    with pytest.raises(ValueError, match="routes to unregistered node"):
+        simple_dag._validate_and_build()
+
+
+def test_dag_with_cycle(bad_dag_with_cycle):
+    with pytest.raises(ValueError, match="DAG contains a cycle"):
+        bad_dag_with_cycle._validate_and_build()
+
+
+def test_dag_with_undefined_output(bad_dag_with_undefined_output):
+    with pytest.raises(ValueError, match=r"has output\(s\) that are not declared"):
+        bad_dag_with_undefined_output._validate_and_build()
+
+
+def test_dag_with_bad_arbiter(bad_dag_with_bad_arbiter, sample_ctx):
+    with pytest.raises(
+        ValueError,
+        match=r"DAG execution completed without reaching an Output node",
+    ):
+        bad_dag_with_bad_arbiter.run(sample_ctx)
 
 
 def test_dag_run(simple_dag, sample_ctx):
     result = simple_dag.run(sample_ctx)
-    assert result.name == "SAFE"
+    assert result.name == "UNSAFE"
 
 
 def test_dag_run_with_dataframe(simple_dag):
@@ -26,3 +69,46 @@
     assert "prompt" in result_df.columns
     assert "response" in result_df.columns
     verdicts = result_df[simple_dag.DATAFRAME_OUTPUT_COL].tolist()
     expected_verdicts = ["SAFE", "UNSAFE", "SAFE", "UNSAFE"]
     assert verdicts == expected_verdicts
+
+
+def test_dag_run_with_dataframe_parallel(simple_dag):
+    df = pd.DataFrame(
+        {
+            "prompt": ["a", "ab", "abc", "abcd"],  # odd, even, odd, even
+            "response": ["hello world", "helloworld", "hello world", "helloworld"],
+        }
+    )
+    result_df = simple_dag.run_dataframe(df, n_jobs=-1)
+
+    assert len(result_df) == len(df)
+    assert "prompt" in result_df.columns
+    assert "response" in result_df.columns
+    verdicts = result_df[simple_dag.DATAFRAME_OUTPUT_COL].tolist()
+    expected_verdicts = ["SAFE", "UNSAFE", "SAFE", "UNSAFE"]
+    assert verdicts == expected_verdicts
+
+
+def test_dag_cost_one_path(simple_dag, sample_ctx):
+    cost = simple_dag.total_cost(sample_ctx)
+    # lower_caser and prompt_parity are at the same level from always_true
+    assert cost == 0.8
+    cost = simple_dag.total_cost()
+    assert cost == 0.8
+
+
+def test_dag_cost_all_paths(simple_dag):
+    costs = simple_dag.total_costs()
+    assert costs == pytest.approx(
+        {
+            "always_true -> always_safe -> SAFE": 1.2,
+            "always_true -> lower_caser -> prompt_parity -> lower_scorer -> upper_scorer -> threshold_arbiter -> SAFE": 3.7,
+            "always_true -> lower_caser -> prompt_parity -> lower_scorer -> upper_scorer -> threshold_arbiter -> UNSAFE": 3.7,
+            "always_true -> lower_caser -> prompt_parity -> upper_caser -> lower_scorer -> upper_scorer -> threshold_arbiter -> SAFE": 4.2,
+            "always_true -> lower_caser -> prompt_parity -> upper_caser -> lower_scorer -> upper_scorer -> threshold_arbiter -> UNSAFE": 4.2,
+        }
+    )
+
+
+def test_dag_visualize_runs(simple_dag, sample_ctx):
+    simple_dag.visualize()
+    simple_dag.visualize_run(sample_ctx)
diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py
index 1aaae07..294ea9e 100644
--- a/tests/unit/evaluator/test_nodes.py
+++ b/tests/unit/evaluator/test_nodes.py
@@ -1,5 +1,11 @@
 """Unit tests for individual EvaluatorDAGNode subclasses."""
 
+import pytest
+
+from modelplane.evaluator.outputs import SAFE, UNSAFE
+
+from .mocks import AlwaysTrue, AlwaysUnsafe, LowerCaser
+
 from .conftest import DEFAULT_BRANCH, FALSE_BRANCH, SCORE1, SCORE2, TRUE_BRANCH
 
 
@@ -46,3 +52,56 @@ def test_threshold_arbiter_false(sample_ctx, threshold_arbiter):
     sample_ctx.set_parent_outputs({"parent0": SCORE1, "parent1": SCORE1})
     output = threshold_arbiter.run(sample_ctx)
     assert output.name == "SAFE"
+
+
+def test_gate_with_two_outputs():
+    with pytest.raises(ValueError, match="has multiple Output routes"):
+        AlwaysTrue(
+            name="bad_gate",
+            routes_true=[SAFE, UNSAFE],
+            routes_false=FALSE_BRANCH,
+        )
+
+
+def test_gate_with_no_true_route():
+    with pytest.raises(ValueError, match="requires both routes_true and routes_false"):
+        AlwaysTrue(
+            name="bad_gate",
+            routes_false=FALSE_BRANCH,
+        )
+
+
+def test_gate_with_routes():
+    with pytest.raises(ValueError, match="should not have routes"):
+        AlwaysTrue(
+            name="bad_gate",
+            routes_true=TRUE_BRANCH,
+            routes_false=FALSE_BRANCH,
+            routes=DEFAULT_BRANCH,
+        )
+
+
+def test_enricher_with_binary_routes():
+    with pytest.raises(
+        ValueError, match="should not have routes_true= / routes_false="
+    ):
+        LowerCaser(
+            name="bad_enricher",
+            routes_true=TRUE_BRANCH,
+            routes=DEFAULT_BRANCH,
+        )
+
+
+def test_enricher_with_no_routes():
+    with pytest.raises(ValueError, match="requires routes="):
+        LowerCaser(
+            name="bad_enricher",
+        )
+
+
+def test_arbiter_with_routes():
+    with pytest.raises(ValueError, match="is terminal and cannot have routing kwargs"):
+        AlwaysUnsafe(
+            name="bad_arbiter",
+            routes=DEFAULT_BRANCH,
+        )

From fb37c3a01afe3347f48dd7d8531b812fcf604868 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Tue, 7 Apr 2026 16:41:15 -0400
Subject: [PATCH 11/19] Skip graphviz test in CI.
---
 tests/unit/evaluator/conftest.py | 4 ++++
 tests/unit/evaluator/test_dag.py | 3 +++
 2 files changed, 7 insertions(+)

diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
index 2043b80..cdb3515 100644
--- a/tests/unit/evaluator/conftest.py
+++ b/tests/unit/evaluator/conftest.py
@@ -1,5 +1,7 @@
 """Shared mock node implementations and helpers for evaluator tests."""
 
+import os
+
 import pytest
 
 from modelplane.evaluator.context import EvalContext
@@ -30,6 +32,8 @@
 SCORE1 = 1.0
 SCORE2 = 2.0
 
+skip_in_ci = pytest.mark.skipif(os.getenv("CI") == "true", reason="skipped in CI")
+
 
 @pytest.fixture
 def always_true_gate() -> AlwaysTrue:
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index 059b3fe..5897092 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -3,6 +3,8 @@
 import pandas as pd
 import pytest
 
+from .conftest import skip_in_ci
+
 from modelplane.evaluator.outputs import SAFE, UNSAFE
 
 
@@ -109,6 +111,7 @@ def test_dag_cost_all_paths(simple_dag):
     )
 
 
+@skip_in_ci
 def test_dag_visualize_runs(simple_dag, sample_ctx):
     simple_dag.visualize()
     simple_dag.visualize_run(sample_ctx)

From ef3bb959cc92cf64a57a0e190531980b8518005c Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Wed, 8 Apr 2026 15:23:35 -0400
Subject: [PATCH 12/19] Refactor safety things to same module, add annotator arbiter.

---
 src/modelplane/evaluator/annotator.py | 32 ++++++++++++++
 src/modelplane/evaluator/dag.py       | 49 +-------------------
 src/modelplane/evaluator/outputs.py   | 14 ------
 src/modelplane/evaluator/safety.py    | 53 +++++++++++++++++++++
 4 files changed, 86 insertions(+), 62 deletions(-)
 create mode 100644 src/modelplane/evaluator/annotator.py
 create mode 100644 src/modelplane/evaluator/safety.py

diff --git a/src/modelplane/evaluator/annotator.py b/src/modelplane/evaluator/annotator.py
new file mode 100644
index 0000000..4371a77
--- /dev/null
+++ b/src/modelplane/evaluator/annotator.py
@@ -0,0 +1,32 @@
+from modelgauge.annotator import Annotator
+from modelgauge.prompt import ChatPrompt, TextPrompt
+from modelgauge.prompt_formatting import format_chat
+from modelgauge.sut import SUTResponse
+
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.dag import EvaluatorDAG
+from modelplane.evaluator.outputs import Output
+
+
+class DAGAnnotator(Annotator):
+    """Annotator that executes a DAG."""
+
+    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
+        super().__init__(uid)
+        self.dag = dag
+
+    def translate_prompt(
+        self,
+        prompt: TextPrompt | ChatPrompt,
+        response: SUTResponse,
+    ) -> EvalContext:
+        prompt_str = (
+            prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt)
+        )
+        return EvalContext(
+            prompt=prompt_str,
+            response=response.text,
+        )
+
+    def annotate(self, annotation_request: EvalContext) -> Output:
+        return self.dag.run(annotation_request)
diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index 1e6819e..ce4fe11 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -3,21 +3,14 @@
 import collections
 import functools
 import os
-from itertools import product
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass, field
+from itertools import product
 from typing import Any, Optional
 
 import pandas as pd
-from modelgauge.annotation import SafetyAnnotation
-from modelgauge.annotator import Annotator
-from modelgauge.prompt import ChatPrompt, TextPrompt
-from modelgauge.prompt_formatting import format_chat
-from modelgauge.sut import SUTResponse
 
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, EvaluatorDAGNode, Gate, Output
-from modelplane.evaluator.outputs import Safety
 
 
 def requires_validate_and_build(method):
@@ -451,43 +444,3 @@ def visualize_run(self, ctx: EvalContext):
             traversed_edges=traversed_edges,
             final_output=final_output,
         )
-
-
-class DAGAnnotator(Annotator):
-    """Annotator that executes a DAG."""
-
-    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
-        super().__init__(uid)
-        self.dag = dag
-
-    def translate_prompt(
-        self,
-        prompt: TextPrompt | ChatPrompt,
-        response: SUTResponse,
-    ) -> EvalContext:
-        prompt_str = (
-            prompt.text if isinstance(prompt, TextPrompt) else format_chat(prompt)
-        )
-        return EvalContext(
-            prompt=prompt_str,
-            response=response.text,
-        )
-
-    def annotate(self, annotation_request: EvalContext) -> Output:
-        return self.dag.run(annotation_request)
-
-
-class SafetyDAGAnnotator(DAGAnnotator):
-
-    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
-        super().__init__(uid, dag)
-        if not all(isinstance(o, Safety) for o in dag.outputs):
-            raise ValueError("All outputs of the DAG must be of type Safety.")
-
-    def translate_response(
-        self,
-        request: EvalContext,
-        response: Safety,
-    ) -> SafetyAnnotation:
-        """Map DAGResult verdict to a SafetyAnnotation (is_safe bool)."""
-        return SafetyAnnotation(is_safe=response.is_safe)
diff --git a/src/modelplane/evaluator/outputs.py b/src/modelplane/evaluator/outputs.py
index 203821e..e086a4c 100644
--- a/src/modelplane/evaluator/outputs.py
+++ b/src/modelplane/evaluator/outputs.py
@@ -9,17 +9,3 @@ def name(self) -> str:
 
     def __repr__(self) -> str:
         return f"{self.name} ({self.__class__.__name__})"
-
-
-class Safety(Output):
-
-    def __init__(self, is_safe: bool) -> None:
-        self.is_safe = is_safe
-
-    @property
-    def name(self) -> str:
-        return "SAFE" if self.is_safe else "UNSAFE"
-
-
-SAFE = Safety(is_safe=True)
-UNSAFE = Safety(is_safe=False)
diff --git a/src/modelplane/evaluator/safety.py b/src/modelplane/evaluator/safety.py
new file mode 100644
index 0000000..1873c74
--- /dev/null
+++ b/src/modelplane/evaluator/safety.py
@@ -0,0 +1,53 @@
+from modelgauge.annotation import SafetyAnnotation
+from modelgauge.annotator import Annotator, SUTResponse, TextPrompt
+
+from modelplane.evaluator.annotator import DAGAnnotator
+from modelplane.evaluator.context import EvalContext
+from modelplane.evaluator.dag import EvaluatorDAG
+from modelplane.evaluator.nodes import Arbiter
+from modelplane.evaluator.outputs import Output
+
+
+class Safety(Output):
+
+    def __init__(self, is_safe: bool) -> None:
+        self.is_safe = is_safe
+
+    @property
+    def name(self) -> str:
+        return "SAFE" if self.is_safe else "UNSAFE"
+
+
+SAFE = Safety(is_safe=True)
+UNSAFE = Safety(is_safe=False)
+
+
+class SafetyDAGAnnotator(DAGAnnotator):
+    """Implementation of DAGAnnotator that produces a SafetyAnnotation."""
+
+    def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
+        super().__init__(uid, dag)
+        if not all(isinstance(o, Safety) for o in dag.outputs):
+            raise ValueError("All outputs of the DAG must be of type Safety.")
+
+    def translate_response(
+        self,
+        request: EvalContext,
+        response: Safety,
+    ) -> SafetyAnnotation:
+        """Map DAGResult verdict to a SafetyAnnotation (is_safe bool)."""
+        return SafetyAnnotation(is_safe=response.is_safe)
+
+
+class AnnotatorArbiter(Arbiter):
+    """Arbiter that outputs SAFE or UNSAFE based on the output of a (safety) Annotator."""
+
+    def __init__(self, name: str, annotator: Annotator) -> None:
+        super().__init__(name=name)
+        self.annotator = annotator
+
+    def run(self, ctx: EvalContext) -> Output:
+        prompt = TextPrompt(text=ctx.prompt)
+        response = SUTResponse(text=ctx.response)
+        annotation = self.annotator.process(prompt, response)
+        return SAFE if annotation.is_safe else UNSAFE

From 96359bf5cc4125e68e335ffbc7d68c928628d9f8 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Wed, 8 Apr 2026 16:06:01 -0400
Subject: [PATCH 13/19] Fix tests.

---
 tests/unit/evaluator/conftest.py   | 3 ++-
 tests/unit/evaluator/mocks.py      | 3 ++-
 tests/unit/evaluator/test_dag.py   | 4 ++--
 tests/unit/evaluator/test_nodes.py | 5 ++---
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py
index cdb3515..a5850f5 100644
--- a/tests/unit/evaluator/conftest.py
+++ b/tests/unit/evaluator/conftest.py
@@ -6,7 +6,8 @@
 
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.dag import EvaluatorDAG
-from modelplane.evaluator.outputs import SAFE, UNSAFE, Output
+from modelplane.evaluator.outputs import Output
+from modelplane.evaluator.safety import SAFE, UNSAFE
 
 from .mocks import (
     AlwaysFalse,
diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py
index ac680e9..3bd7b51 100644
--- a/tests/unit/evaluator/mocks.py
+++ b/tests/unit/evaluator/mocks.py
@@ -1,6 +1,7 @@
 from modelplane.evaluator.context import EvalContext
 from modelplane.evaluator.nodes import Arbiter, Enricher, Gate, Scorer
-from modelplane.evaluator.outputs import SAFE, UNSAFE, Output
+from modelplane.evaluator.outputs import Output
+from modelplane.evaluator.safety import SAFE, UNSAFE
 
 
 class PassthroughGate(Gate):
diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py
index 5897092..dea2544 100644
--- a/tests/unit/evaluator/test_dag.py
+++ b/tests/unit/evaluator/test_dag.py
@@ -3,9 +3,9 @@
 import pandas as pd
 import pytest
 
-from .conftest import skip_in_ci
+from modelplane.evaluator.safety import SAFE, UNSAFE
 
-from modelplane.evaluator.outputs import SAFE, UNSAFE
+from .conftest import skip_in_ci
 
 
 def test_dag_outputs(simple_dag):
diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py
index 294ea9e..6bd5850 100644
--- a/tests/unit/evaluator/test_nodes.py
+++ b/tests/unit/evaluator/test_nodes.py
@@ -2,11 +2,10 @@
 
 import pytest
 
-from modelplane.evaluator.outputs import SAFE, UNSAFE
-
-from .mocks import AlwaysTrue, AlwaysUnsafe, LowerCaser
+from modelplane.evaluator.safety import SAFE, UNSAFE
 
 from .conftest import DEFAULT_BRANCH, FALSE_BRANCH, SCORE1, SCORE2, TRUE_BRANCH
+from .mocks import AlwaysTrue, AlwaysUnsafe, LowerCaser
 
 
 def test_true_routes_to_true_branch(sample_ctx, always_true_gate):

From 5f36d1c5f064959e9d1e4e627d200eb80ab4801c Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Thu, 9 Apr 2026 10:38:47 -0400
Subject: [PATCH 14/19] Fix missing method for annotator arbiter.
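
Arbiter.outputs() is abstract, so AnnotatorArbiter could not be
instantiated at all before this fix. A minimal sketch of the intended
wiring, using the constructor signatures as of this patch
(FakeSafetyAnnotator is a hypothetical stand-in, not a real modelgauge
annotator):

    annotator = FakeSafetyAnnotator("fake-uid")  # hypothetical
    arbiter = AnnotatorArbiter(name="llm_judge", annotator=annotator)
    dag = (
        EvaluatorDAG("annotator_dag", outputs=[SAFE, UNSAFE])
        .add_node(arbiter)
    )
    result = dag.run(ctx)  # SAFE or UNSAFE, per the wrapped annotator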
---
 src/modelplane/evaluator/safety.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/modelplane/evaluator/safety.py b/src/modelplane/evaluator/safety.py
index 1873c74..dcdb30e 100644
--- a/src/modelplane/evaluator/safety.py
+++ b/src/modelplane/evaluator/safety.py
@@ -46,8 +46,11 @@ def __init__(self, name: str, annotator: Annotator) -> None:
         super().__init__(name=name)
         self.annotator = annotator
 
-    def run(self, ctx: EvalContext) -> Output:
+    def run(self, ctx: EvalContext) -> Safety:
         prompt = TextPrompt(text=ctx.prompt)
         response = SUTResponse(text=ctx.response)
         annotation = self.annotator.process(prompt, response)
         return SAFE if annotation.is_safe else UNSAFE
+
+    def outputs(self) -> list[Safety]:
+        return [SAFE, UNSAFE]

From 39cf4f11ea96dab3870e290ad307b7215f25af1a Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Mon, 13 Apr 2026 12:54:02 -0400
Subject: [PATCH 15/19] Refactor things.

---
 src/modelplane/evaluator/dag.py    | 68 +++++++++++++-----------
 src/modelplane/evaluator/nodes.py  | 47 +++++++--------------
 src/modelplane/evaluator/safety.py |  9 +---
 3 files changed, 47 insertions(+), 77 deletions(-)

diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py
index ce4fe11..14c42fd 100644
--- a/src/modelplane/evaluator/dag.py
+++ b/src/modelplane/evaluator/dag.py
@@ -27,16 +27,17 @@ class EvaluatorDAG:
 
     Usage:
 
-        refusal_gate = MyRefusalGate("RefusalGate", routes_true=[NONVIOLATING], routes_false=["NonRefusal"])
+        refusal_gate = MyRefusalGate("RefusalGate", routes_true=[Safety(is_safe=True)], routes_false=["NonRefusal"])
         eval_non_refusal = MyNonRefusalEvaluator("NonRefusal", routes=["Arbiter"])
         arbiter = MyArbiter("Arbiter")
 
         dag = (
-            EvaluatorDAG("refusal_gated_safety_evaluator", outputs=[NONVIOLATING, VIOLATING])
+            EvaluatorDAG("refusal_gated_safety_evaluator", output_type=Safety)
             .add_node(refusal_gate)
             .add_node(eval_non_refusal)
             .add_node(arbiter)
         )
+
     # run single
     result = dag.run(prompt_uid="123", prompt="...", response="...")
     # run batch
     results_df = dag.run_dataframe(df)
@@ -45,19 +46,16 @@ class EvaluatorDAG:
 
     DATAFRAME_OUTPUT_COL = "output"
 
-    def __init__(self, name: str, outputs: list[Output]) -> None:
+    def __init__(self, name: str, output_type: type) -> None:
         self.name = name
         self._nodes: dict[str, EvaluatorDAGNode] = {}
         self._root_nodes: list[str] = []
         self._ordered: list[str] = []
         self._validated: bool = False
         self._predecessors: dict[str, list[str]] = collections.defaultdict(list)
-        self._outputs = {output.name: output for output in outputs}
-
-    @property
-    def outputs(self) -> list[Output]:
-        """Return the list of Output nodes declared in the DAG constructor."""
-        return list(self._outputs.values())
+        if not issubclass(output_type, Output):
+            raise ValueError("output_type must be a subclass of Output.")
+        self._output_type = output_type
 
     def add_node(
         self,
@@ -65,7 +63,7 @@ def add_node(
     ) -> "EvaluatorDAG":
        """Register a node with its routes."""
 
-        if node.name in self._all_names():
+        if node.name in self._nodes:
             raise ValueError(
                 f"A different node named {node.name} is already registered."
             )
@@ -73,16 +71,12 @@ def add_node(
         self._validated = False
         return self
 
-    def _all_names(self) -> dict[str, EvaluatorDAGNode | Output]:
-        return {**self._nodes, **self._outputs}
-
     def _validate_and_build(self) -> None:
         """
         Validate the DAG:
-        - All routes reference registered nodes.
+        - All routes reference registered nodes or instances of the output type.
         - No cycles.
-        - All paths lead to an Output node.
-        - All Output nodes are declared as outputs in the DAG constructor.
+        - All paths lead to an instance of the output type.
 
         Build:
         - _predecessors: dict mapping node name to list of parent node names (for context during execution)
@@ -93,23 +87,24 @@ def _validate_and_build(self) -> None:
         if self._validated:
             return
 
-        all_named_entities = self._all_names()
-        # check that all route targets reference registered nodes
+        # check that all route targets reference registered nodes or instances of the output type
         for node_name, node in self._nodes.items():
             for target in node.all_routes():
-                if target not in all_named_entities:
+                if target not in self._nodes and not isinstance(
+                    target, self._output_type
+                ):
                     raise ValueError(
-                        f"Node {node_name!r} routes to unregistered node {target!r}."
+                        f"Node {node_name} routes to unregistered node {target} or incompatible output."
                     )
 
         # check for cycles (kahn's algorithm)
         all_routes = {name: node.all_routes() for name, node in self._nodes.items()}
         in_degree: dict[str, int] = {n: 0 for n in self._nodes}
-        for route in all_routes.values():
-            for t in route:
-                if t in self._outputs:
+        for routes in all_routes.values():
+            for route in routes:
+                if isinstance(route, Output):
                     continue
-                in_degree[t] += 1
+                in_degree[route] += 1
 
         root_nodes = [n for n in self._nodes if in_degree[n] == 0]
         queue = collections.deque(root_nodes)
@@ -118,7 +113,7 @@ def _validate_and_build(self) -> None:
             current = queue.popleft()
             ordered.append(current)
             for child in all_routes.get(current, []):
-                if child in self._outputs:
+                if isinstance(child, Output):
                     continue
                 in_degree[child] -= 1
                 if in_degree[child] == 0:
@@ -128,19 +123,21 @@ def _validate_and_build(self) -> None:
             nodes_in_cycle = set(self._nodes) - set(ordered)
             raise ValueError(f"DAG contains a cycle. Nodes in cycle: {nodes_in_cycle}")
 
-        # check all terminal Arbiter nodes have correct outputs
+        # check all terminal Arbiter nodes have correct output types
         terminal_nodes = [n for n in self._nodes if not all_routes.get(n)]
         for terminal in terminal_nodes:
-            entity = all_named_entities[terminal]
-            if isinstance(entity, Arbiter):
-                if any(o.name not in self._outputs for o in entity.outputs()):
+            node = self._nodes[terminal]
+            if isinstance(node, Arbiter):
+                if not issubclass(node.output_type, self._output_type):
                     raise ValueError(
-                        f"Terminal Arbiter node {terminal!r} has output(s) that are not declared as outputs in the DAG constructor."
+                        f"Terminal Arbiter node {terminal} has output_type {node.output_type}, which is not compatible with the DAG's output_type {self._output_type}."
                     )
 
-        # get predecessors
+        # build predecessors
         for name, node in self._nodes.items():
             for target in node.all_routes():
+                if isinstance(target, Output):
+                    continue
                 self._predecessors[target].append(name)
 
         self._validated = True
@@ -242,7 +239,6 @@ def total_costs(self) -> dict[str, float]:
             reachable: set[str] = set(self._root_nodes)
             path: list[str] = []
             total = 0.0
-            terminal_outputs: list[str] = []
 
             for node_name in self._ordered:
                 if node_name not in reachable:
                     continue
                 node = self._nodes[node_name]
                 total += node.cost(ctx)
                 path.append(node_name)
                 if isinstance(node, Gate):
                     targets = (
                         node.routes_true
                         if gate_outcomes[node_name]
                         else node.routes_false
                     )
                 elif isinstance(node, Arbiter):
-                    terminal_outputs = [o.name for o in node.outputs()]
                     targets = []
                 else:
                     targets = node.routes
                 for target in targets:
-                    if isinstance(target, Output):
-                        terminal_outputs = [target.name]
-                    else:
+                    if not isinstance(target, Output):
                         reachable.add(
                             target if isinstance(target, str) else target.name
                         )
 
             base_path = " -> ".join(path)
-            for output_name in terminal_outputs:
-                path_costs[f"{base_path} -> {output_name}"] = total
+            path_costs[f"{base_path} -> {self._output_type}"] = total
 
         return path_costs
 
diff --git a/src/modelplane/evaluator/nodes.py b/src/modelplane/evaluator/nodes.py
index 7e1aadc..7b0d58a 100644
--- a/src/modelplane/evaluator/nodes.py
+++ b/src/modelplane/evaluator/nodes.py
@@ -5,10 +5,9 @@
 
     EvaluatorNode (ABC)
     ├── Gate (binary test; routes on True/False)
-    ├── Enricher (transforms context; routes unconditionally)
-    ├── Scorer (produces a float score; routes unconditionally)
-    └── Arbiter (produces output)
-        Output (terminal node; carries a verdict value)
+    ├── Enricher (produces arbitrary output; routes forward unconditionally)
+    ├── Arbiter (produces output; routes to outputs only)
+    └── Output (terminal node; carries a verdict value)
 """
 
 from abc import ABC, abstractmethod
@@ -27,21 +26,21 @@ def __init__(
         routes: Optional[Sequence[str | Output]] = None,
     ) -> None:
         self.name = name
-        self._routes_true: tuple[str | Output] = tuple(routes_true or [])
-        self._routes_false: tuple[str | Output] = tuple(routes_false or [])
-        self._routes: tuple[str | Output] = tuple(routes or [])
+        self._routes_true: tuple[str | Output, ...] = tuple(routes_true or [])
+        self._routes_false: tuple[str | Output, ...] = tuple(routes_false or [])
+        self._routes: tuple[str | Output, ...] = tuple(routes or [])
         self.validate()
 
     @property
-    def routes_true(self) -> tuple[str | Output]:
+    def routes_true(self) -> tuple[str | Output, ...]:
         return self._routes_true
 
     @property
-    def routes_false(self) -> tuple[str | Output]:
+    def routes_false(self) -> tuple[str | Output, ...]:
         return self._routes_false
 
     @property
-    def routes(self) -> tuple[str | Output]:
+    def routes(self) -> tuple[str | Output, ...]:
         return self._routes
 
     @abstractmethod
@@ -64,15 +63,11 @@ def format_output(self, output: Any) -> str:
         s = str(output)
         return s if len(s) <= 30 else s[:27] + "..."
 
-    def all_routes(self) -> list[str]:
+    def all_routes(self) -> list[str | Output]:
         """Return a list of all route targets from this node."""
-        return [
-            *[r if isinstance(r, str) else r.name for r in self.routes_true],
-            *[r if isinstance(r, str) else r.name for r in self.routes_false],
-            *[r if isinstance(r, str) else r.name for r in self.routes],
-        ]
+        return [*self.routes_true, *self.routes_false, *self.routes]
 
-    def next_nodes(self, output: Any) -> tuple[str | Output]:
+    def next_nodes(self, output: Any) -> tuple[str | Output, ...]:
         """Given the node's output value, return the tuple of next node names to activate."""
         if isinstance(self, Gate):
             return self.routes_true if output else self.routes_false
@@ -137,18 +132,6 @@ def validate(self) -> None:
         _validate_unary_routes(self)
 
 
-class Scorer(EvaluatorDAGNode):
-    """Scoring node. Produces a float score from the (possibly enriched) context."""
-
-    @abstractmethod
-    def run(self, ctx: EvalContext) -> float:
-        """Return a score for the current context."""
-
-    def validate(self) -> None:
-        super().validate()
-        _validate_unary_routes(self)
-
-
 class Arbiter(EvaluatorDAGNode):
     """Takes context and returns an Output indicating the final verdict (based on routes)."""
 
@@ -160,6 +143,8 @@ def validate(self) -> None:
         super().validate()
         _validate_terminal(self)
 
+    @property
     @abstractmethod
-    def outputs(self) -> list[Output]:
-        """Return the list of possible Output verdicts this Arbiter can return."""
+    def output_type(self) -> type:
+        """Return the expected type of the Output's value for validation."""
+        raise NotImplementedError
diff --git a/src/modelplane/evaluator/safety.py b/src/modelplane/evaluator/safety.py
index dcdb30e..98e7fb9 100644
--- a/src/modelplane/evaluator/safety.py
+++ b/src/modelplane/evaluator/safety.py
@@ -18,10 +18,6 @@ def name(self) -> str:
         return "SAFE" if self.is_safe else "UNSAFE"
 
 
-SAFE = Safety(is_safe=True)
-UNSAFE = Safety(is_safe=False)
-
-
 class SafetyDAGAnnotator(DAGAnnotator):
     """Implementation of DAGAnnotator that produces a SafetyAnnotation."""
 
@@ -50,7 +46,4 @@ def run(self, ctx: EvalContext) -> Safety:
         prompt = TextPrompt(text=ctx.prompt)
         response = SUTResponse(text=ctx.response)
         annotation = self.annotator.process(prompt, response)
-        return SAFE if annotation.is_safe else UNSAFE
-
-    def outputs(self) -> list[Safety]:
-        return [SAFE, UNSAFE]
+        return Safety(is_safe=annotation.is_safe)

From 405297881666e5dbcb30488b4573b09e1cd0f26c Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Mon, 13 Apr 2026 14:32:50 -0400
Subject: [PATCH 16/19] Visualization updates.
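
Reworks the rendered graph: left-to-right flow, fixed-size nodes with
font sizes scaled to fit, an input parallelogram that can echo the
evaluated prompt/response, per-instance nodes for directly routed
Outputs, and a dashed placeholder node for the Arbiter output type.
Sketch of the intended notebook flow (the DAG wiring itself is
whatever the caller has built):

    ctx = EvalContext(prompt="hi", response="hello there")
    dag.visualize()         # static structure; dashed output placeholder
    dag.visualize_run(ctx)  # executed path bold, inactive nodes dimmed

Both entry points raise RuntimeError with install hints when the
Graphviz system binaries are missing.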
--- src/modelplane/evaluator/dag.py | 171 ++++++++++++++++++++++------ src/modelplane/evaluator/outputs.py | 2 +- src/modelplane/evaluator/safety.py | 10 +- 3 files changed, 144 insertions(+), 39 deletions(-) diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py index 14c42fd..c5e4e7a 100644 --- a/src/modelplane/evaluator/dag.py +++ b/src/modelplane/evaluator/dag.py @@ -57,6 +57,10 @@ def __init__(self, name: str, output_type: type) -> None: raise ValueError("output_type must be a subclass of Output.") self._output_type = output_type + @property + def output_type(self) -> type: + return self._output_type + def add_node( self, node: EvaluatorDAGNode, @@ -130,7 +134,7 @@ def _validate_and_build(self) -> None: if isinstance(node, Arbiter): if not issubclass(node.output_type, self._output_type): raise ValueError( - f"Terminal Arbiter node {terminal} has output_type {node.output_type}, which is not compatible with the DAG's output_type {self._output_type}." + f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self._output_type.__name__}." ) # build predecessors @@ -263,7 +267,7 @@ def total_costs(self) -> dict[str, float]: ) base_path = " -> ".join(path) - path_costs[f"{base_path} -> {self._output_type}"] = total + path_costs[f"{base_path} -> Out ({self._output_type.__name__})"] = total return path_costs @@ -272,6 +276,7 @@ def _visualize( node_outputs: Optional[dict[str, Any]] = None, traversed_edges: Optional[set[tuple[str, str]]] = None, final_output: Optional[Output] = None, + ctx: Optional[EvalContext] = None, ): """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline. @@ -284,11 +289,20 @@ def _visualize( traced = node_outputs is not None _NODE_STYLES: dict[type, dict] = { - Gate: {"shape": "diamond", "style": "filled", "fillcolor": "#d0e8f5"}, - Arbiter: {"shape": "box", "style": "filled", "fillcolor": "#c8e6c9"}, - Output: {"shape": "ellipse", "style": "filled", "fillcolor": "#fff9c4"}, + Gate: {"shape": "diamond", "style": "filled", "fillcolor": "#ffe082"}, + Arbiter: {"shape": "hexagon", "style": "filled", "fillcolor": "#e1bee7"}, + Output: { + "shape": "rectangle", + "style": "filled,rounded", + "fillcolor": "#dcedc8", + }, + } + _OUTPUT_TYPE_STYLE = { + "shape": "rectangle", + "style": "filled,rounded,dashed", + "fillcolor": "#dcedc8", } - _DEFAULT_STYLE = {"shape": "box", "style": "filled", "fillcolor": "#ffe0b2"} + _DEFAULT_STYLE = {"shape": "rectangle", "style": "filled", "fillcolor": "#eeeeee"} _DIM = { "style": "filled", "fillcolor": "#f0f0f0", @@ -296,44 +310,104 @@ def _visualize( "fontcolor": "#aaaaaa", } + _NODE_W, _NODE_H = 1.5, 0.5 # inches, fixed for all nodes + + def _fontsize( + label: str, max_fs: float = 11.0, min_fs: float = 7.0, fill: float = 0.8 + ) -> str: + """Scale font size so the longest line fits within _NODE_W. + + fill: fraction of the node width usable for text. Shapes like diamonds, + hexagons, and parallelograms have less usable area than rectangles, so + pass a smaller fill value for those. 
+ """ + longest = max((len(line) for line in label.split("\n")), default=1) + # approx: each char ≈ 0.55 × fontsize points + fs = (_NODE_W * 72 * fill) / (longest * 0.55) + return f"{max(min_fs, min(max_fs, fs)):.1f}" + dot = graphviz.Digraph(name=self.name) dot.attr( label=self.name, labelloc="t", fontsize="13", fontname="Helvetica", - rankdir="TB", + rankdir="LR", ranksep="0.5", nodesep="0.4", ) - dot.attr("node", fontname="Helvetica", fontsize="11") - dot.attr("edge", fontname="Helvetica", fontsize="10") + dot.attr( + "node", + fontname="Helvetica", + fontsize="11", + width=str(_NODE_W), + height=str(_NODE_H), + fixedsize="true", + ) + dot.attr("edge", fontname="Helvetica", fontsize="9") - # implicit input node pinned to the top + # implicit input node pinned to the left top = graphviz.Digraph() top.attr(rank="min") + + def _truncate(s: str, n: int = 24) -> str: + return s if len(s) <= n else s[: n - 1] + "…" + + if ctx is not None: + input_label = f"p: {_truncate(ctx.prompt)}\nr: {_truncate(ctx.response)}" + else: + input_label = "prompt\nresponse" top.node( "__input__", - "prompt\nresponse", - shape="box", - style="dashed", - fillcolor="white", - color="#888888", - fontcolor="#555555", + input_label, + shape="parallelogram", + style="filled", + fillcolor="#b2dfdb", + color="#4db6ac", + fontcolor="#00695c", + fontsize=_fontsize(input_label, fill=0.45), ) dot.subgraph(top) - # output terminal nodes pinned to the bottom + # collect Output instances directly referenced in routes (from non-Arbiter nodes) + direct_outputs: dict[str, Output] = {} + has_arbiter = any(isinstance(n, Arbiter) for n in self._nodes.values()) + for node in self._nodes.values(): + if not isinstance(node, Arbiter): + for target in node.all_routes(): + if isinstance(target, Output): + direct_outputs[target.name] = target + + # whether the final output came from a direct route or an arbiter + final_from_direct = traced and final_output in direct_outputs.values() + bottom = graphviz.Digraph() bottom.attr(rank="max") - for output_name, output_node in self._outputs.items(): + + # individual nodes for directly-routed Output instances, shown with their repr + for out_name, out_inst in direct_outputs.items(): attrs = dict(_NODE_STYLES[Output]) if traced: - if output_node is final_output: + if out_inst is final_output: attrs["penwidth"] = "2.5" else: - attrs = dict(_DIM, shape="ellipse") - bottom.node(output_name, **attrs) + attrs = dict(_DIM, shape="rectangle", style="filled,rounded") + bottom.node(out_name, repr(out_inst), fontsize=_fontsize(repr(out_inst)), **attrs) + + # synthetic output type node for Arbiters + if has_arbiter: + output_node_id = f"__output_{self._output_type.__name__}__" + output_label = f"{self._output_type.__name__} (?)" + attrs = dict(_OUTPUT_TYPE_STYLE) + if traced: + if not final_from_direct and final_output is not None: + attrs = dict(_NODE_STYLES[Output]) + attrs["penwidth"] = "2.5" + output_label = repr(final_output) + elif final_from_direct: + attrs = dict(_DIM, shape="rectangle", style="filled,rounded") + bottom.node(output_node_id, output_label, fontsize=_fontsize(output_label), **attrs) + dot.subgraph(bottom) # processing nodes @@ -359,13 +433,12 @@ def _visualize( attrs["penwidth"] = "2.5" else: label = node_name - dot.node(node_name, label, **attrs) + _fill = 0.45 if isinstance(node, Gate) else 0.65 if isinstance(node, Arbiter) else 0.8 + dot.node(node_name, label, fontsize=_fontsize(label, fill=_fill), **attrs) - # dashed edges from implicit input to root nodes + # edges from implicit 
input to root nodes
        for root in self._root_nodes:
-            dot.edge(
-                "__input__", root, style="dashed", color="#888888", arrowhead="open"
-            )
+            dot.edge("__input__", root, color="#888888")
 
         # edges between processing nodes
         for node_name, node in self._nodes.items():
@@ -393,22 +466,27 @@
                     penwidth="2" if hot and traced else "1",
                 )
             elif isinstance(node, Arbiter):
-                for output in node.outputs():
-                    hot = not traced or (node_name, output.name) in traversed_edges  # type: ignore[operator]
-                    dot.edge(
-                        node_name,
-                        output.name,
-                        color="#555555" if hot else "#cccccc",
-                        penwidth="2" if hot and traced else "1",
-                    )
+                output_node_id = f"__output_{self._output_type.__name__}__"
+                hot = not traced or node_name in (node_outputs or {})
+                dot.edge(
+                    node_name,
+                    output_node_id,
+                    color="#555555" if hot else "#cccccc",
+                    penwidth="2" if hot and traced else "1",
+                )
             else:
                 for target in node.routes:
                     t = target if isinstance(target, str) else target.name
                     hot = not traced or (node_name, t) in traversed_edges  # type: ignore[operator]
+                    edge_label = ""
+                    if traced and hot and node_name in (node_outputs or {}):
+                        edge_label = f" {node.format_output(node_outputs[node_name])}"  # type: ignore[index]
                     dot.edge(
                         node_name,
                         t,
+                        label=edge_label,
                         color="#555555" if hot else "#cccccc",
+                        fontcolor="#555555" if hot else "#cccccc",
                         penwidth="2" if hot and traced else "1",
                     )
 
@@ -424,15 +502,38 @@
     @requires_validate_and_build
     def visualize(self):
-        """Visualize the DAG structure without execution."""
+        """Render the DAG structure as a PNG image (inline in Jupyter notebooks).
+
+        The graph flows left to right. Node shapes and colors:
+        - Input — teal parallelogram (implicit; represents the prompt/response pair)
+        - Gate — amber diamond; edges labelled "True" (green) / "False" (red)
+        - Enricher — light grey rectangle; edges are unlabelled
+        - Arbiter — light purple hexagon; edge leads to the Output type placeholder
+        - Output (direct instance) — soft green rounded rectangle, solid border;
+          label is repr(output)
+        - Output (type placeholder) — soft green rounded rectangle, dashed border;
+          label is the class name; shown when the DAG contains
+          an Arbiter whose concrete value is only known at runtime
+
+        Raises:
+            RuntimeError: if the Graphviz system binaries are not installed.
+        """
         return self._visualize()
 
     @requires_validate_and_build
     def visualize_run(self, ctx: EvalContext):
-        """Run the DAG on ctx and return a visualization with the executed path highlighted."""
+        """Run the DAG on ctx and return a visualization with the executed path highlighted.
+
+        Identical layout to visualize(), with the following additions:
+        - Active nodes are bolded and show their output value beneath the node name.
+        - Inactive nodes are greyed out.
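+        - Traversed edges are drawn bold; edges out of Enrichers are
+          additionally labelled with the node's formatted output value.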
+        """
         final_output, node_outputs, traversed_edges = self._run_traced(ctx)
         return self._visualize(
             node_outputs=node_outputs,
             traversed_edges=traversed_edges,
             final_output=final_output,
+            ctx=ctx,
         )
diff --git a/src/modelplane/evaluator/outputs.py b/src/modelplane/evaluator/outputs.py
index e086a4c..9810492 100644
--- a/src/modelplane/evaluator/outputs.py
+++ b/src/modelplane/evaluator/outputs.py
@@ -8,4 +8,4 @@ def name(self) -> str:
         """Return a string name for this output, used for routing and debugging."""
 
     def __repr__(self) -> str:
-        return f"{self.name} ({self.__class__.__name__})"
+        return f"{self.__class__.__name__} ({self.name})"
diff --git a/src/modelplane/evaluator/safety.py b/src/modelplane/evaluator/safety.py
index 98e7fb9..0bf5dbf 100644
--- a/src/modelplane/evaluator/safety.py
+++ b/src/modelplane/evaluator/safety.py
@@ -18,12 +18,20 @@ def name(self) -> str:
         return "SAFE" if self.is_safe else "UNSAFE"
 
 
+class SafetyArbiter(Arbiter):
+    """Arbiter base class whose output_type is Safety."""
+
+    @property
+    def output_type(self) -> type:
+        return Safety
+
+
 class SafetyDAGAnnotator(DAGAnnotator):
     """Implementation of DAGAnnotator that produces a SafetyAnnotation."""
 
     def __init__(self, uid: str, dag: EvaluatorDAG) -> None:
         super().__init__(uid, dag)
-        if not all(isinstance(o, Safety) for o in dag.outputs):
-            raise ValueError("All outputs of the DAG must be of type Safety.")
+        if not issubclass(dag.output_type, Safety):
+            raise ValueError("The DAG's output_type must be Safety or a subclass of it.")
 
     def translate_response(
@@ -35,7 +43,7 @@ def translate_response(
         return SafetyAnnotation(is_safe=response.is_safe)
 
 
-class AnnotatorArbiter(Arbiter):
+class AnnotatorArbiter(SafetyArbiter):
     """Arbiter that outputs SAFE or UNSAFE based on the output of a (safety) Annotator."""
 
     def __init__(self, name: str, annotator: Annotator) -> None:
From 65a66da730c7c72436805434802f74128b707979 Mon Sep 17 00:00:00 2001
From: Vishal Doshi
Date: Mon, 13 Apr 2026 14:33:11 -0400
Subject: [PATCH 17/19] Test updates.
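
Migrates the tests off the removed SAFE/UNSAFE singletons: DAGs are now
built with output_type=Safety and outputs are constructed inline as
Safety(is_safe=...). Mock Arbiters implement the output_type property
introduced earlier, e.g. (mirroring mocks.py):

    class AlwaysUnsafe(Arbiter):
        def run(self, ctx: EvalContext) -> Output:
            return Safety(is_safe=False)

        @property
        def output_type(self) -> type:
            return Safety

Also adds a one_step_dag fixture (a Gate routing straight to an Output
instance), a regression test for missing Graphviz binaries, and a new
test_safety.py covering AnnotatorArbiter and SafetyDAGAnnotator.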
--- tests/unit/evaluator/conftest.py | 29 ++++++++++++---- tests/unit/evaluator/mocks.py | 51 +++++++++++++++------------- tests/unit/evaluator/test_dag.py | 52 ++++++++++++++++++++--------- tests/unit/evaluator/test_nodes.py | 4 +-- tests/unit/evaluator/test_safety.py | 25 ++++++++++++++ 5 files changed, 114 insertions(+), 47 deletions(-) create mode 100644 tests/unit/evaluator/test_safety.py diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py index a5850f5..b044152 100644 --- a/tests/unit/evaluator/conftest.py +++ b/tests/unit/evaluator/conftest.py @@ -7,7 +7,7 @@ from modelplane.evaluator.context import EvalContext from modelplane.evaluator.dag import EvaluatorDAG from modelplane.evaluator.outputs import Output -from modelplane.evaluator.safety import SAFE, UNSAFE +from modelplane.evaluator.safety import Safety from .mocks import ( AlwaysFalse, @@ -97,10 +97,25 @@ def threshold_arbiter() -> ThresholdArbiter: return ThresholdArbiter(name="threshold_arbiter", threshold=1.5) +@pytest.fixture +def one_step_dag(): + return ( + EvaluatorDAG("one_step", output_type=Safety) + .add_node( + AlwaysFalse( + name="gate", + routes_true=[Safety(is_safe=True)], + routes_false=["always_unsafe"], + ) + ) + .add_node(AlwaysUnsafe(name="always_unsafe")) + ) + + @pytest.fixture def simple_dag(): return ( - EvaluatorDAG("simple", outputs=[SAFE, UNSAFE]) + EvaluatorDAG("simple", output_type=Safety) .add_node( AlwaysTrue( name="always_true", @@ -112,7 +127,7 @@ def simple_dag(): .add_node( PromptLengthGate( name="prompt_parity", - routes_true=[UNSAFE], + routes_true=[Safety(is_safe=False)], routes_false=["upper_caser"], ) ) @@ -131,7 +146,7 @@ def simple_dag(): @pytest.fixture() def bad_dag_with_cycle(): return ( - EvaluatorDAG("cyclic", outputs=[SAFE, UNSAFE]) + EvaluatorDAG("cyclic", output_type=Safety) .add_node( AlwaysTrue( name="node1", @@ -149,8 +164,8 @@ def bad_dag_with_cycle(): .add_node( AlwaysTrue( name="node3", - routes_true=[SAFE], - routes_false=[UNSAFE], + routes_true=[Safety(is_safe=True)], + routes_false=[Safety(is_safe=False)], ) ) ) @@ -165,6 +180,6 @@ def bad_dag_with_undefined_output(simple_dag): @pytest.fixture def bad_dag_with_bad_arbiter(): - dag = EvaluatorDAG("test", outputs=[SAFE, UNSAFE]) + dag = EvaluatorDAG("test", output_type=Safety) dag.add_node(BadArbiter(name="bad_arbiter")) return dag diff --git a/tests/unit/evaluator/mocks.py b/tests/unit/evaluator/mocks.py index 3bd7b51..0e1307d 100644 --- a/tests/unit/evaluator/mocks.py +++ b/tests/unit/evaluator/mocks.py @@ -1,7 +1,7 @@ from modelplane.evaluator.context import EvalContext -from modelplane.evaluator.nodes import Arbiter, Enricher, Gate, Scorer +from modelplane.evaluator.nodes import Arbiter, Enricher, Gate from modelplane.evaluator.outputs import Output -from modelplane.evaluator.safety import SAFE, UNSAFE +from modelplane.evaluator.safety import Safety class PassthroughGate(Gate): @@ -62,7 +62,7 @@ def cost(self, ctx: EvalContext) -> float: return 0.6 -class FixedScorer(Scorer): +class FixedScorer(Enricher): """Returns a fixed float score regardless of context.""" def __init__(self, name: str, value: float, **kwargs): @@ -76,7 +76,7 @@ def cost(self, ctx: EvalContext) -> float: return 0.7 -class LowerCaseScorer(Scorer): +class LowerCaseScorer(Enricher): """Scores based on the percentage of lowercase characters in the response.""" def run(self, ctx: EvalContext) -> float: @@ -89,7 +89,7 @@ def cost(self, ctx: EvalContext) -> float: return 0.8 -class UpperCaseScorer(Scorer): +class 
UpperCaseScorer(Enricher): """Scores based on the percentage of uppercase characters in the response.""" def run(self, ctx: EvalContext) -> float: @@ -104,25 +104,27 @@ def cost(self, ctx: EvalContext) -> float: class AlwaysUnsafe(Arbiter): def run(self, ctx: EvalContext) -> Output: - return UNSAFE - - def outputs(self) -> list[Output]: - return [UNSAFE] + return Safety(is_safe=False) def cost(self, ctx: EvalContext) -> float: return 1.0 + @property + def output_type(self) -> type: + return Safety + class AlwaysSafe(Arbiter): def run(self, ctx: EvalContext) -> Output: - return SAFE - - def outputs(self) -> list[Output]: - return [SAFE] + return Safety(is_safe=True) def cost(self, ctx: EvalContext) -> float: return 1.1 + @property + def output_type(self) -> type: + return Safety + class ThresholdArbiter(Arbiter): def __init__(self, name: str, threshold: float, **kwargs): @@ -132,14 +134,15 @@ def __init__(self, name: str, threshold: float, **kwargs): def run(self, ctx: EvalContext) -> Output: scores = ctx.parent_outputs() score = sum(scores) / len(scores) - return UNSAFE if score >= self.threshold else SAFE - - def outputs(self) -> list[Output]: - return [UNSAFE, SAFE] + return Safety(is_safe=score < self.threshold) def cost(self, ctx: EvalContext) -> float: return 1.2 + @property + def output_type(self) -> type: + return Safety + class UnexpectedOutput(Output): @property @@ -153,12 +156,13 @@ class UnexpectedArbiter(Arbiter): def run(self, ctx: EvalContext) -> Output: return UnexpectedOutput() - def outputs(self) -> list[Output]: - return [UnexpectedOutput()] - def cost(self, ctx: EvalContext) -> float: return 1.3 + @property + def output_type(self) -> type: + return UnexpectedOutput + class BadArbiter(Arbiter): """An arbiter that violates the contract by returning a non-Output value.""" @@ -166,8 +170,9 @@ class BadArbiter(Arbiter): def run(self, ctx: EvalContext) -> str: return "safe" - def outputs(self) -> list[Output]: - return [SAFE] - def cost(self, ctx: EvalContext) -> float: return 1.4 + + @property + def output_type(self) -> type: + return Safety diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py index dea2544..49e9bda 100644 --- a/tests/unit/evaluator/test_dag.py +++ b/tests/unit/evaluator/test_dag.py @@ -1,15 +1,26 @@ """Unit tests for EvaluatorDAG construction, validation, execution, and visualization.""" +from unittest.mock import patch + import pandas as pd import pytest -from modelplane.evaluator.safety import SAFE, UNSAFE +from modelplane.evaluator.dag import EvaluatorDAG +from modelplane.evaluator.safety import Safety from .conftest import skip_in_ci def test_dag_outputs(simple_dag): - assert simple_dag.outputs == [SAFE, UNSAFE] + assert simple_dag.output_type == Safety + + +def test_dag_with_bad_output_type(): + with pytest.raises( + ValueError, + match="output_type must be a subclass of Output", + ): + EvaluatorDAG(name="bad_dag", output_type=str) def test_add_node_with_same_name_as_existing_node(simple_dag, always_true_gate): @@ -18,12 +29,6 @@ def test_add_node_with_same_name_as_existing_node(simple_dag, always_true_gate): simple_dag.add_node(always_true_gate) # same name as existing node -def test_add_node_with_same_name_as_output(simple_dag, always_true_gate): - always_true_gate.name = SAFE.name - with pytest.raises(ValueError, match="is already registered"): - simple_dag.add_node(always_true_gate) # same name as existing output - - def test_add_node_with_undefined_target_node(simple_dag, bad_gate): simple_dag.add_node(bad_gate) with 
pytest.raises(ValueError, match="routes to unregistered node"): @@ -36,7 +41,9 @@ def test_dag_with_cycle(bad_dag_with_cycle): def test_dag_with_undefined_output(bad_dag_with_undefined_output): - with pytest.raises(ValueError, match=r"has output\(s\) that are not declared"): + with pytest.raises( + ValueError, match=r"which is not compatible with the DAG\'s output_type" + ): bad_dag_with_undefined_output._validate_and_build() @@ -102,16 +109,31 @@ def test_dag_cost_all_paths(simple_dag): costs = simple_dag.total_costs() assert costs == pytest.approx( { - "always_true -> always_safe -> SAFE": 1.2, - "always_true -> lower_caser -> prompt_parity -> lower_scorer -> upper_scorer -> threshold_arbiter -> SAFE": 3.7, - "always_true -> lower_caser -> prompt_parity -> lower_scorer -> upper_scorer -> threshold_arbiter -> UNSAFE": 3.7, - "always_true -> lower_caser -> prompt_parity -> upper_caser -> lower_scorer -> upper_scorer -> threshold_arbiter -> SAFE": 4.2, - "always_true -> lower_caser -> prompt_parity -> upper_caser -> lower_scorer -> upper_scorer -> threshold_arbiter -> UNSAFE": 4.2, + "always_true -> always_safe -> Out (Safety)": 1.2, + "always_true -> lower_caser -> prompt_parity -> lower_scorer -> upper_scorer -> threshold_arbiter -> Out (Safety)": 3.7, + "always_true -> lower_caser -> prompt_parity -> upper_caser -> lower_scorer -> upper_scorer -> threshold_arbiter -> Out (Safety)": 4.2, } ) @skip_in_ci -def test_dag_visualize_runs(simple_dag, sample_ctx): +def test_dag_visualize_runs(simple_dag, one_step_dag, sample_ctx): simple_dag.visualize() simple_dag.visualize_run(sample_ctx) + one_step_dag.visualize() + one_step_dag.visualize_run(sample_ctx) + + +def test_visualize_raises_when_graphviz_binary_missing(simple_dag): + import graphviz + + with patch.object( + graphviz.Digraph, + "pipe", + side_effect=graphviz.ExecutableNotFound(["dot"]), + ): + with pytest.raises( + RuntimeError, + match="Graphviz system binaries not found", + ): + simple_dag.visualize() diff --git a/tests/unit/evaluator/test_nodes.py b/tests/unit/evaluator/test_nodes.py index 6bd5850..10725a9 100644 --- a/tests/unit/evaluator/test_nodes.py +++ b/tests/unit/evaluator/test_nodes.py @@ -2,7 +2,7 @@ import pytest -from modelplane.evaluator.safety import SAFE, UNSAFE +from modelplane.evaluator.safety import Safety from .conftest import DEFAULT_BRANCH, FALSE_BRANCH, SCORE1, SCORE2, TRUE_BRANCH from .mocks import AlwaysTrue, AlwaysUnsafe, LowerCaser @@ -57,7 +57,7 @@ def test_gate_with_two_outputs(): with pytest.raises(ValueError, match="has multiple Output routes"): AlwaysTrue( name="bad_gate", - routes_true=[SAFE, UNSAFE], + routes_true=[Safety(is_safe=True), Safety(is_safe=False)], routes_false=FALSE_BRANCH, ) diff --git a/tests/unit/evaluator/test_safety.py b/tests/unit/evaluator/test_safety.py new file mode 100644 index 0000000..2999166 --- /dev/null +++ b/tests/unit/evaluator/test_safety.py @@ -0,0 +1,25 @@ +from modelgauge.annotation import SafetyAnnotation +from modelgauge.annotators.demo_annotator import DemoYBadAnnotator +from modelgauge.prompt import TextPrompt +from modelgauge.sut import SUTResponse + +from modelplane.evaluator.safety import AnnotatorArbiter, Safety, SafetyDAGAnnotator + + +def test_safety_annotator_arbiter(sample_ctx): + annotator = DemoYBadAnnotator("demo_annotator") + arbiter = AnnotatorArbiter(name="demo_arbiter", annotator=annotator) + output = arbiter.run(sample_ctx) + assert output.is_safe + assert isinstance(output, Safety) + assert arbiter.output_type == Safety + + +def 
test_safety_dag_run(simple_dag, sample_ctx): + safety_annotator = SafetyDAGAnnotator("safety", simple_dag) + output = safety_annotator.process( + prompt=TextPrompt(text=sample_ctx.prompt), + response=SUTResponse(text=sample_ctx.response), + ) + assert not output.is_safe + assert isinstance(output, SafetyAnnotation) From f72bdfb726e19d1d84c72fbb2c8050df5684e3a0 Mon Sep 17 00:00:00 2001 From: Vishal Doshi Date: Mon, 13 Apr 2026 14:43:39 -0400 Subject: [PATCH 18/19] Use output type property. --- src/modelplane/evaluator/dag.py | 36 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py index c5e4e7a..d143a12 100644 --- a/src/modelplane/evaluator/dag.py +++ b/src/modelplane/evaluator/dag.py @@ -95,7 +95,7 @@ def _validate_and_build(self) -> None: for node_name, node in self._nodes.items(): for target in node.all_routes(): if target not in self._nodes and not isinstance( - target, self._output_type + target, self.output_type ): raise ValueError( f"Node {node_name} routes to unregistered node {target} or incompatible output." @@ -132,9 +132,9 @@ def _validate_and_build(self) -> None: for terminal in terminal_nodes: node = self._nodes[terminal] if isinstance(node, Arbiter): - if not issubclass(node.output_type, self._output_type): + if not issubclass(node.output_type, self.output_type): raise ValueError( - f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self._output_type.__name__}." + f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}." ) # build predecessors @@ -267,7 +267,7 @@ def total_costs(self) -> dict[str, float]: ) base_path = " -> ".join(path) - path_costs[f"{base_path} -> Out ({self._output_type.__name__})"] = total + path_costs[f"{base_path} -> Out ({self.output_type.__name__})"] = total return path_costs @@ -282,6 +282,8 @@ def _visualize( When node_outputs/traversed_edges/final_output are provided (via visualize_run), the hot path is highlighted and each node shows its output value. + + NOTE: this helper method is vibe-coded and provided as-is. 
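+        Raises RuntimeError ("Graphviz system binaries not found") when
+        graphviz cannot invoke the system `dot` executable.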
""" import graphviz from IPython.display import Image @@ -302,7 +304,11 @@ def _visualize( "style": "filled,rounded,dashed", "fillcolor": "#dcedc8", } - _DEFAULT_STYLE = {"shape": "rectangle", "style": "filled", "fillcolor": "#eeeeee"} + _DEFAULT_STYLE = { + "shape": "rectangle", + "style": "filled", + "fillcolor": "#eeeeee", + } _DIM = { "style": "filled", "fillcolor": "#f0f0f0", @@ -392,12 +398,14 @@ def _truncate(s: str, n: int = 24) -> str: attrs["penwidth"] = "2.5" else: attrs = dict(_DIM, shape="rectangle", style="filled,rounded") - bottom.node(out_name, repr(out_inst), fontsize=_fontsize(repr(out_inst)), **attrs) + bottom.node( + out_name, repr(out_inst), fontsize=_fontsize(repr(out_inst)), **attrs + ) # synthetic output type node for Arbiters if has_arbiter: - output_node_id = f"__output_{self._output_type.__name__}__" - output_label = f"{self._output_type.__name__} (?)" + output_node_id = f"__output_{self.output_type.__name__}__" + output_label = f"{self.output_type.__name__} (?)" attrs = dict(_OUTPUT_TYPE_STYLE) if traced: if not final_from_direct and final_output is not None: @@ -406,7 +414,9 @@ def _truncate(s: str, n: int = 24) -> str: output_label = repr(final_output) elif final_from_direct: attrs = dict(_DIM, shape="rectangle", style="filled,rounded") - bottom.node(output_node_id, output_label, fontsize=_fontsize(output_label), **attrs) + bottom.node( + output_node_id, output_label, fontsize=_fontsize(output_label), **attrs + ) dot.subgraph(bottom) @@ -433,7 +443,11 @@ def _truncate(s: str, n: int = 24) -> str: attrs["penwidth"] = "2.5" else: label = node_name - _fill = 0.45 if isinstance(node, Gate) else 0.65 if isinstance(node, Arbiter) else 0.8 + _fill = ( + 0.45 + if isinstance(node, Gate) + else 0.65 if isinstance(node, Arbiter) else 0.8 + ) dot.node(node_name, label, fontsize=_fontsize(label, fill=_fill), **attrs) # edges from implicit input to root nodes @@ -466,7 +480,7 @@ def _truncate(s: str, n: int = 24) -> str: penwidth="2" if hot and traced else "1", ) elif isinstance(node, Arbiter): - output_node_id = f"__output_{self._output_type.__name__}__" + output_node_id = f"__output_{self.output_type.__name__}__" hot = not traced or node_name in (node_outputs or {}) dot.edge( node_name, From c9ec5d170830df11a2d49a3f8d121289a2c253a7 Mon Sep 17 00:00:00 2001 From: Vishal Doshi Date: Mon, 13 Apr 2026 16:14:40 -0400 Subject: [PATCH 19/19] Refactor validation a bit and add test. --- src/modelplane/evaluator/dag.py | 18 +++++++----------- tests/unit/evaluator/conftest.py | 16 ++++++++++++++++ tests/unit/evaluator/test_dag.py | 8 ++++++++ 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/modelplane/evaluator/dag.py b/src/modelplane/evaluator/dag.py index d143a12..b6f093b 100644 --- a/src/modelplane/evaluator/dag.py +++ b/src/modelplane/evaluator/dag.py @@ -91,8 +91,14 @@ def _validate_and_build(self) -> None: if self._validated: return - # check that all route targets reference registered nodes or instances of the output type + # check that all route targets reference registered nodes or instances + # of the output type, and that all Arbiters have compatible output types for node_name, node in self._nodes.items(): + if isinstance(node, Arbiter): + if not issubclass(node.output_type, self.output_type): + raise ValueError( + f"Node {node_name} is an Arbiter with output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}." 
+ ) for target in node.all_routes(): if target not in self._nodes and not isinstance( target, self.output_type @@ -127,16 +133,6 @@ def _validate_and_build(self) -> None: nodes_in_cycle = set(self._nodes) - set(ordered) raise ValueError(f"DAG contains a cycle. Nodes in cycle: {nodes_in_cycle}") - # check all terminal Arbiter nodes have correct output types - terminal_nodes = [n for n in self._nodes if not all_routes.get(n)] - for terminal in terminal_nodes: - node = self._nodes[terminal] - if isinstance(node, Arbiter): - if not issubclass(node.output_type, self.output_type): - raise ValueError( - f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}." - ) - # build predecessors for name, node in self._nodes.items(): for target in node.all_routes(): diff --git a/tests/unit/evaluator/conftest.py b/tests/unit/evaluator/conftest.py index b044152..2da1064 100644 --- a/tests/unit/evaluator/conftest.py +++ b/tests/unit/evaluator/conftest.py @@ -22,6 +22,7 @@ PromptLengthGate, ThresholdArbiter, UnexpectedArbiter, + UnexpectedOutput, UpperCaser, UpperCaseScorer, ) @@ -183,3 +184,18 @@ def bad_dag_with_bad_arbiter(): dag = EvaluatorDAG("test", output_type=Safety) dag.add_node(BadArbiter(name="bad_arbiter")) return dag + + +@pytest.fixture +def bad_one_step_dag(): + return ( + EvaluatorDAG("one_step", output_type=Safety) + .add_node( + AlwaysFalse( + name="gate", + routes_true=[UnexpectedOutput()], + routes_false=["always_unsafe"], + ) + ) + .add_node(AlwaysUnsafe(name="always_unsafe")) + ) diff --git a/tests/unit/evaluator/test_dag.py b/tests/unit/evaluator/test_dag.py index 49e9bda..0158e04 100644 --- a/tests/unit/evaluator/test_dag.py +++ b/tests/unit/evaluator/test_dag.py @@ -55,6 +55,14 @@ def test_dag_with_bad_arbiter(bad_dag_with_bad_arbiter, sample_ctx): bad_dag_with_bad_arbiter.run(sample_ctx) +def test_dag_with_bad_output_route(bad_one_step_dag, sample_ctx): + with pytest.raises( + ValueError, + match=r"incompatible output", + ): + bad_one_step_dag.run(sample_ctx) + + def test_dag_run(simple_dag, sample_ctx): result = simple_dag.run(sample_ctx) assert result.name == "UNSAFE"