Skip to content

Commit 0db4854

Browse files
committed
Nice visualization plus other fixes.
1 parent 7c0aca9 commit 0db4854

4 files changed

Lines changed: 222 additions & 42 deletions

File tree

Dockerfile.jupyter

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
FROM python:3.12-slim
22

33
ENV PATH="/root/.local/bin:$PATH"
4+
ENV PYTHONPATH="/app/flightpaths/flights"
45
# Used for the notebook server
56
WORKDIR /app
67

78
# pipx needed for uv installation script
89
# ssh client needed for installing private modelbench dependencies
910
# git needed for dvc
10-
RUN apt-get update && apt-get install -y pipx openssh-client git && \
11+
RUN apt-get update && apt-get install -y pipx openssh-client git graphviz && \
1112
pipx install uv
1213
COPY pyproject.toml uv.lock README.md ./
1314

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies = [
2323
"scikit-learn>=1.5.0,<2.0.0",
2424
"pandas>=2.2.2,<4",
2525
"modelbench @ git+https://github.com/mlcommons/modelbench.git",
26+
"graphviz>=0.20,<1",
2627
]
2728

2829
[project.scripts]

src/modelplane/evaluator/dag.py

Lines changed: 208 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
"""DAGAnnotator and EvaluatorDAG implementation."""
22

33
import collections
4-
from dataclasses import dataclass, field
54
import functools
5+
import os
6+
from concurrent.futures import ThreadPoolExecutor
7+
from dataclasses import dataclass, field
68
from typing import Any, Optional
79

810
import pandas as pd
11+
from modelgauge.annotation import SafetyAnnotation
912
from modelgauge.annotator import Annotator
10-
11-
from modelplane.evaluator.context import EvalContext
12-
from modelplane.evaluator.nodes import (
13-
Arbiter,
14-
EvaluatorDAGNode,
15-
Gate,
16-
Output,
17-
)
1813
from modelgauge.prompt import ChatPrompt, TextPrompt
1914
from modelgauge.prompt_formatting import format_chat
2015
from modelgauge.sut import SUTResponse
21-
from modelgauge.annotation import SafetyAnnotation
16+
17+
from modelplane.evaluator.context import EvalContext
18+
from modelplane.evaluator.nodes import Arbiter, EvaluatorDAGNode, Gate, Output
2219

2320

2421
def requires_validate_and_build(method):
@@ -110,6 +107,8 @@ def _validate_and_build(self) -> None:
110107
in_degree: dict[str, int] = {n: 0 for n in self._nodes}
111108
for route in all_routes.values():
112109
for t in route:
110+
if t in self._outputs:
111+
continue
113112
in_degree[t] += 1
114113

115114
root_nodes = [n for n in self._nodes if in_degree[n] == 0]
@@ -119,6 +118,8 @@ def _validate_and_build(self) -> None:
119118
current = queue.popleft()
120119
ordered.append(current)
121120
for child in all_routes.get(current, []):
121+
if child in self._outputs:
122+
continue
122123
in_degree[child] -= 1
123124
if in_degree[child] == 0:
124125
queue.append(child)
@@ -155,44 +156,54 @@ def _validate_and_build(self) -> None:
155156
self._root_nodes = root_nodes
156157
self._ordered = ordered
157158

158-
@requires_validate_and_build
159-
def run(
160-
self,
161-
ctx: EvalContext,
162-
) -> Output:
163-
"""
164-
Execute the DAG on a single prompt/response.
165-
"""
159+
def _run_traced(
    self, ctx: EvalContext
) -> tuple[Output, dict[str, Any], set[tuple[str, str]]]:
    """Execute the DAG and return (final output, node outputs, traversed edges).

    Nodes are executed wave by wave, starting from ``self._root_nodes``.
    Before each node runs, the outputs of its already-executed predecessors
    are placed on ``ctx`` via ``ctx.set_parent_outputs``.

    Returns:
        A 3-tuple of:
        - the ``Output`` that terminated execution,
        - a mapping of processing-node name -> value returned by that node,
        - the set of ``(source, target)`` name pairs that were followed.

    Raises:
        ValueError: if every wave completes without reaching an ``Output``.
    """
    node_outputs: dict[str, Any] = {}
    traversed_edges: set[tuple[str, str]] = set()
    active_nodes = list(self._root_nodes)
    while active_nodes:
        next_active: list[str] = []
        # Names already queued for the next wave. Without this, a child that
        # is routed to by two parents in the same wave would run twice.
        scheduled: set[str] = set()
        for node_name in active_nodes:
            ctx.set_parent_outputs(
                {
                    pred: node_outputs[pred]
                    for pred in self._predecessors[node_name]
                    if pred in node_outputs
                }
            )
            node = self._nodes[node_name]
            output = node.run(ctx)
            if isinstance(output, Output):
                # The node itself produced the terminal verdict.
                traversed_edges.add((node_name, output.name))
                return output, node_outputs, traversed_edges
            node_outputs[node_name] = output
            for target in node.next_nodes(output):
                # Routes may be given as names or as node objects.
                t = target if isinstance(target, str) else target.name
                traversed_edges.add((node_name, t))
                if isinstance(target, Output):
                    return target, node_outputs, traversed_edges
                if t in self._outputs:
                    # Route expressed as the *name* of an output: resolve it
                    # here instead of queueing it as a processing node, which
                    # would fail on the self._nodes lookup in the next wave.
                    return self._outputs[t], node_outputs, traversed_edges
                if t not in scheduled:
                    scheduled.add(t)
                    next_active.append(t)
        active_nodes = next_active
    raise ValueError("DAG execution completed without reaching an Output node.")
189190

191+
@requires_validate_and_build
def run(self, ctx: EvalContext) -> Output:
    """Execute the DAG on a single prompt/response and return its final Output.

    Delegates to ``_run_traced`` and discards the per-node outputs and the
    traversed-edge set that it also returns.
    """
    final_output = self._run_traced(ctx)[0]
    return final_output
199+
190200
@requires_validate_and_build
191201
def run_dataframe(
192202
self,
193203
df: pd.DataFrame,
194204
prompt_col: str = "prompt",
195205
response_col: str = "response",
206+
n_jobs: int = 1,
196207
) -> pd.DataFrame:
197208
"""Run the DAG over every row of a DataFrame."""
198209

@@ -203,7 +214,14 @@ def _run_row(row: Any) -> Output:
203214
)
204215
return self.run(ctx)
205216

206-
records = [_run_row(row) for _, row in df.iterrows()]
217+
rows = [row for _, row in df.iterrows()]
218+
219+
if n_jobs == 1:
220+
records = [_run_row(row) for row in rows]
221+
else:
222+
max_workers = os.cpu_count() if n_jobs == -1 else n_jobs
223+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
224+
records = list(executor.map(_run_row, rows))
207225

208226
result_df = pd.DataFrame(
209227
{self.DATAFRAME_OUTPUT_COL: [r.name for r in records]}, index=df.index
@@ -241,20 +259,169 @@ def _dfs(node_name: str, accumulated: float, path: list[str]) -> None:
241259
return path_costs
242260

243261
@requires_validate_and_build
244-
def visualize(self) -> None:
245-
"""Render the DAG structure with ascii."""
246-
print(f"EvaluatorDAG: {self.name!r}")
247-
print("=" * (len(self.name) + 18))
248-
for node_name in self._ordered:
249-
node = self._nodes[node_name]
250-
node_type = type(node).__name__
251-
if isinstance(node, Output):
252-
route_str = f" → verdict='{node.name}'"
253-
elif isinstance(node, Gate):
254-
route_str = f" → True:{node.routes_true} False:{node.routes_false}"
262+
def visualize(
    self,
    node_outputs: Optional[dict[str, Any]] = None,
    traversed_edges: Optional[set[tuple[str, str]]] = None,
    final_output: Optional[Output] = None,
):
    """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline.

    When node_outputs/traversed_edges/final_output are provided (via visualize_run),
    the hot path is highlighted and each node shows its output value.

    Args:
        node_outputs: Mapping of node name -> value produced by that node
            during a run. Passing it (even empty) switches to "traced" mode,
            in which nodes missing from the mapping are dimmed.
        traversed_edges: ``(source, target)`` name pairs followed during the
            run. ``None`` in traced mode is treated as "no edge traversed".
        final_output: The Output reached by the run; its terminal node is
            drawn with a thicker outline.

    Returns:
        An ``IPython.display.Image`` wrapping the rendered PNG bytes.

    Raises:
        RuntimeError: if the Graphviz system binaries are not installed.
    """
    import graphviz
    from IPython.display import Image

    traced = node_outputs is not None
    # Tolerate a trace supplied without edges rather than failing on `in None`.
    hot_edges: set[tuple[str, str]] = (
        traversed_edges if traversed_edges is not None else set()
    )
    # Output nodes are keyed by name in the graph, so also match the final
    # output by name in case the run returned an equal-but-distinct instance.
    final_name = final_output.name if final_output is not None else None

    def _format_output(value: Any) -> str:
        """Compact, label-friendly rendering of a node's output value."""
        if isinstance(value, float):
            return f"{value:.3g}"
        s = str(value)
        return s if len(s) <= 30 else s[:27] + "..."

    def _target_name(target: Any) -> str:
        """Routes may hold node names or node objects; normalise to a name."""
        return target if isinstance(target, str) else target.name

    _NODE_STYLES: dict[type, dict] = {
        Gate: {"shape": "diamond", "style": "filled", "fillcolor": "#d0e8f5"},
        Arbiter: {"shape": "box", "style": "filled", "fillcolor": "#c8e6c9"},
        Output: {"shape": "ellipse", "style": "filled", "fillcolor": "#fff9c4"},
    }
    _DEFAULT_STYLE = {"shape": "box", "style": "filled", "fillcolor": "#ffe0b2"}
    _DIM = {
        "style": "filled",
        "fillcolor": "#f0f0f0",
        "color": "#bbbbbb",
        "fontcolor": "#aaaaaa",
    }
    _COLD = "#cccccc"

    dot = graphviz.Digraph(name=self.name)
    dot.attr(
        label=self.name,
        labelloc="t",
        fontsize="13",
        fontname="Helvetica",
        rankdir="TB",
        ranksep="0.5",
        nodesep="0.4",
    )
    dot.attr("node", fontname="Helvetica", fontsize="11")
    dot.attr("edge", fontname="Helvetica", fontsize="10")

    def _add_edge(src: str, dst: str, hot_color: str, label: Optional[str] = None) -> None:
        """Draw one edge, highlighted when it lies on the traced path.

        Outside traced mode every edge uses its normal colour at width 1.
        """
        hot = not traced or (src, dst) in hot_edges
        color = hot_color if hot else _COLD
        attrs = {"color": color, "penwidth": "2" if hot and traced else "1"}
        if label is not None:
            attrs["label"] = label
            attrs["fontcolor"] = color
        dot.edge(src, dst, **attrs)

    # implicit input node pinned to the top
    top = graphviz.Digraph()
    top.attr(rank="min")
    top.node(
        "__input__",
        "prompt\nresponse",
        shape="box",
        style="dashed",
        fillcolor="white",
        color="#888888",
        fontcolor="#555555",
    )
    dot.subgraph(top)

    # output terminal nodes pinned to the bottom
    bottom = graphviz.Digraph()
    bottom.attr(rank="max")
    for output_name, output_node in self._outputs.items():
        attrs = dict(_NODE_STYLES[Output])
        if traced:
            if output_node is final_output or output_name == final_name:
                attrs["penwidth"] = "2.5"
            else:
                attrs = dict(_DIM, shape="ellipse")
        bottom.node(output_name, **attrs)
    dot.subgraph(bottom)

    # processing nodes
    for node_name, node in self._nodes.items():
        base_style = next(
            (s for t, s in _NODE_STYLES.items() if isinstance(node, t)),
            _DEFAULT_STYLE,
        )
        if traced and node_name not in node_outputs:
            # Node was never executed in this run: grey it out, keep its shape.
            attrs = dict(_DIM, shape=base_style.get("shape", "box"))
            label = node_name
        else:
            attrs = dict(base_style)
            if traced:
                raw = node_outputs[node_name]  # type: ignore[index]
                label = f"{node_name}\n{_format_output(raw)}"
            else:
                label = node_name
        dot.node(node_name, label, **attrs)

    # dashed edges from implicit input to root nodes
    for root in self._root_nodes:
        dot.edge(
            "__input__", root, style="dashed", color="#888888", arrowhead="open"
        )

    # edges between processing nodes
    for node_name, node in self._nodes.items():
        if isinstance(node, Gate):
            for target in node.routes_true:
                _add_edge(node_name, _target_name(target), "#2e7d32", label=" True")
            for target in node.routes_false:
                _add_edge(node_name, _target_name(target), "#c62828", label=" False")
        elif isinstance(node, Arbiter):
            for output in node.outputs():
                _add_edge(node_name, output.name, "#555555")
        else:
            for target in node.routes:
                _add_edge(node_name, _target_name(target), "#555555")

    try:
        return Image(dot.pipe(format="png"))
    except graphviz.ExecutableNotFound:
        # The Python package is only a wrapper; rendering needs the `dot` binary.
        raise RuntimeError(
            "Graphviz system binaries not found. Install them with:\n"
            "  macOS:  brew install graphviz\n"
            "  Ubuntu: apt-get install graphviz\n"
            "  conda:  conda install graphviz"
        ) from None
416+
@requires_validate_and_build
def visualize_run(self, ctx: EvalContext):
    """Run the DAG on ctx and return a PNG with the executed path highlighted.

    The trace produced by ``_run_traced`` (final output, per-node outputs and
    traversed edges) is forwarded unchanged to ``visualize``.
    """
    reached, outputs_by_node, edges_taken = self._run_traced(ctx)
    trace_kwargs = {
        "node_outputs": outputs_by_node,
        "traversed_edges": edges_taken,
        "final_output": reached,
    }
    return self.visualize(**trace_kwargs)
258425

259426

260427
class DAGAnnotator(Annotator):

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)