11"""DAGAnnotator and EvaluatorDAG implementation."""
22
33import collections
4- from dataclasses import dataclass , field
54import functools
5+ import os
6+ from concurrent .futures import ThreadPoolExecutor
7+ from dataclasses import dataclass , field
68from typing import Any , Optional
79
810import pandas as pd
11+ from modelgauge .annotation import SafetyAnnotation
912from modelgauge .annotator import Annotator
10-
11- from modelplane .evaluator .context import EvalContext
12- from modelplane .evaluator .nodes import (
13- Arbiter ,
14- EvaluatorDAGNode ,
15- Gate ,
16- Output ,
17- )
1813from modelgauge .prompt import ChatPrompt , TextPrompt
1914from modelgauge .prompt_formatting import format_chat
2015from modelgauge .sut import SUTResponse
21- from modelgauge .annotation import SafetyAnnotation
16+
17+ from modelplane .evaluator .context import EvalContext
18+ from modelplane .evaluator .nodes import Arbiter , EvaluatorDAGNode , Gate , Output
2219
2320
2421def requires_validate_and_build (method ):
@@ -110,6 +107,8 @@ def _validate_and_build(self) -> None:
110107 in_degree : dict [str , int ] = {n : 0 for n in self ._nodes }
111108 for route in all_routes .values ():
112109 for t in route :
110+ if t in self ._outputs :
111+ continue
113112 in_degree [t ] += 1
114113
115114 root_nodes = [n for n in self ._nodes if in_degree [n ] == 0 ]
@@ -119,6 +118,8 @@ def _validate_and_build(self) -> None:
119118 current = queue .popleft ()
120119 ordered .append (current )
121120 for child in all_routes .get (current , []):
121+ if child in self ._outputs :
122+ continue
122123 in_degree [child ] -= 1
123124 if in_degree [child ] == 0 :
124125 queue .append (child )
@@ -155,44 +156,54 @@ def _validate_and_build(self) -> None:
155156 self ._root_nodes = root_nodes
156157 self ._ordered = ordered
157158
158- @requires_validate_and_build
159- def run (
160- self ,
161- ctx : EvalContext ,
162- ) -> Output :
163- """
164- Execute the DAG on a single prompt/response.
165- """
159+ def _run_traced (
160+ self , ctx : EvalContext
161+ ) -> tuple [Output , dict [str , Any ], set [tuple [str , str ]]]:
162+ """Execute the DAG and return (final output, node outputs, traversed edges)."""
166163 active_nodes = self ._root_nodes
167- outputs : dict [str , Any ] = {}
164+ node_outputs : dict [str , Any ] = {}
165+ traversed_edges : set [tuple [str , str ]] = set ()
168166 while active_nodes :
169167 next_active = []
170168 for node_name in active_nodes :
171- # set parent outputs in context for this node
172169 ctx .set_parent_outputs (
173170 {
174- pred : outputs [pred ]
171+ pred : node_outputs [pred ]
175172 for pred in self ._predecessors [node_name ]
176- if pred in outputs
173+ if pred in node_outputs
177174 }
178175 )
179- # run the node
180176 node = self ._nodes [node_name ]
181177 output = node .run (ctx )
182178 if isinstance (output , Output ):
183- return output
184- outputs [node_name ] = output
185- # see which nodes to activate next based on output and routing
186- next_active .extend (node .next_nodes (output ))
179+ traversed_edges .add ((node_name , output .name ))
180+ return output , node_outputs , traversed_edges
181+ node_outputs [node_name ] = output
182+ for target in node .next_nodes (output ):
183+ t = target if isinstance (target , str ) else target .name
184+ traversed_edges .add ((node_name , t ))
185+ if isinstance (target , Output ):
186+ return target , node_outputs , traversed_edges
187+ next_active .append (t )
187188 active_nodes = next_active
188189 raise ValueError ("DAG execution completed without reaching an Output node." )
189190
191+ @requires_validate_and_build
192+ def run (
193+ self ,
194+ ctx : EvalContext ,
195+ ) -> Output :
196+ """Execute the DAG on a single prompt/response."""
197+ output , _ , _ = self ._run_traced (ctx )
198+ return output
199+
190200 @requires_validate_and_build
191201 def run_dataframe (
192202 self ,
193203 df : pd .DataFrame ,
194204 prompt_col : str = "prompt" ,
195205 response_col : str = "response" ,
206+ n_jobs : int = 1 ,
196207 ) -> pd .DataFrame :
197208 """Run the DAG over every row of a DataFrame."""
198209
@@ -203,7 +214,14 @@ def _run_row(row: Any) -> Output:
203214 )
204215 return self .run (ctx )
205216
206- records = [_run_row (row ) for _ , row in df .iterrows ()]
217+ rows = [row for _ , row in df .iterrows ()]
218+
219+ if n_jobs == 1 :
220+ records = [_run_row (row ) for row in rows ]
221+ else :
222+ max_workers = os .cpu_count () if n_jobs == - 1 else n_jobs
223+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
224+ records = list (executor .map (_run_row , rows ))
207225
208226 result_df = pd .DataFrame (
209227 {self .DATAFRAME_OUTPUT_COL : [r .name for r in records ]}, index = df .index
@@ -241,20 +259,169 @@ def _dfs(node_name: str, accumulated: float, path: list[str]) -> None:
241259 return path_costs
242260
243261 @requires_validate_and_build
244- def visualize (self ) -> None :
245- """Render the DAG structure with ascii."""
246- print (f"EvaluatorDAG: { self .name !r} " )
247- print ("=" * (len (self .name ) + 18 ))
248- for node_name in self ._ordered :
249- node = self ._nodes [node_name ]
250- node_type = type (node ).__name__
251- if isinstance (node , Output ):
252- route_str = f" → verdict='{ node .name } '"
253- elif isinstance (node , Gate ):
254- route_str = f" → True:{ node .routes_true } False:{ node .routes_false } "
262+ def visualize (
263+ self ,
264+ node_outputs : Optional [dict [str , Any ]] = None ,
265+ traversed_edges : Optional [set [tuple [str , str ]]] = None ,
266+ final_output : Optional [Output ] = None ,
267+ ):
268+ """Render the DAG as a PNG image. In a Jupyter notebook the image is displayed inline.
269+
270+ When node_outputs/traversed_edges/final_output are provided (via visualize_run),
271+ the hot path is highlighted and each node shows its output value.
272+ """
273+ import graphviz
274+ from IPython .display import Image
275+
276+ traced = node_outputs is not None
277+
278+ def _format_output (value : Any ) -> str :
279+ if isinstance (value , float ):
280+ return f"{ value :.3g} "
281+ s = str (value )
282+ return s if len (s ) <= 30 else s [:27 ] + "..."
283+
284+ _NODE_STYLES : dict [type , dict ] = {
285+ Gate : {"shape" : "diamond" , "style" : "filled" , "fillcolor" : "#d0e8f5" },
286+ Arbiter : {"shape" : "box" , "style" : "filled" , "fillcolor" : "#c8e6c9" },
287+ Output : {"shape" : "ellipse" , "style" : "filled" , "fillcolor" : "#fff9c4" },
288+ }
289+ _DEFAULT_STYLE = {"shape" : "box" , "style" : "filled" , "fillcolor" : "#ffe0b2" }
290+ _DIM = {
291+ "style" : "filled" ,
292+ "fillcolor" : "#f0f0f0" ,
293+ "color" : "#bbbbbb" ,
294+ "fontcolor" : "#aaaaaa" ,
295+ }
296+
297+ dot = graphviz .Digraph (name = self .name )
298+ dot .attr (
299+ label = self .name ,
300+ labelloc = "t" ,
301+ fontsize = "13" ,
302+ fontname = "Helvetica" ,
303+ rankdir = "TB" ,
304+ ranksep = "0.5" ,
305+ nodesep = "0.4" ,
306+ )
307+ dot .attr ("node" , fontname = "Helvetica" , fontsize = "11" )
308+ dot .attr ("edge" , fontname = "Helvetica" , fontsize = "10" )
309+
310+ # implicit input node pinned to the top
311+ top = graphviz .Digraph ()
312+ top .attr (rank = "min" )
313+ top .node (
314+ "__input__" ,
315+ "prompt\n response" ,
316+ shape = "box" ,
317+ style = "dashed" ,
318+ fillcolor = "white" ,
319+ color = "#888888" ,
320+ fontcolor = "#555555" ,
321+ )
322+ dot .subgraph (top )
323+
324+ # output terminal nodes pinned to the bottom
325+ bottom = graphviz .Digraph ()
326+ bottom .attr (rank = "max" )
327+ for output_name , output_node in self ._outputs .items ():
328+ attrs = dict (_NODE_STYLES [Output ])
329+ if traced :
330+ if output_node is final_output :
331+ attrs ["penwidth" ] = "2.5"
332+ else :
333+ attrs = dict (_DIM , shape = "ellipse" )
334+ bottom .node (output_name , ** attrs )
335+ dot .subgraph (bottom )
336+
337+ # processing nodes
338+ for node_name , node in self ._nodes .items ():
339+ base_style = next (
340+ (s for t , s in _NODE_STYLES .items () if isinstance (node , t )),
341+ _DEFAULT_STYLE ,
342+ )
343+ if traced and node_name not in node_outputs :
344+ attrs = dict (_DIM , shape = base_style .get ("shape" , "box" ))
345+ label = node_name
255346 else :
256- route_str = f" → { node .routes } "
257- print (f" [{ node_type :10s} ] { node_name } { route_str } " )
347+ attrs = dict (base_style )
348+ if traced :
349+ raw = node_outputs [node_name ] # type: ignore[index]
350+ label = f"{ node_name } \n { _format_output (raw )} "
351+ else :
352+ label = node_name
353+ dot .node (node_name , label , ** attrs )
354+
355+ # dashed edges from implicit input to root nodes
356+ for root in self ._root_nodes :
357+ dot .edge (
358+ "__input__" , root , style = "dashed" , color = "#888888" , arrowhead = "open"
359+ )
360+
361+ # edges between processing nodes
362+ for node_name , node in self ._nodes .items ():
363+ if isinstance (node , Gate ):
364+ for target in node .routes_true :
365+ t = target if isinstance (target , str ) else target .name
366+ hot = not traced or (node_name , t ) in traversed_edges # type: ignore[operator]
367+ dot .edge (
368+ node_name ,
369+ t ,
370+ label = " True" ,
371+ color = "#2e7d32" if hot else "#cccccc" ,
372+ fontcolor = "#2e7d32" if hot else "#cccccc" ,
373+ penwidth = "2" if hot and traced else "1" ,
374+ )
375+ for target in node .routes_false :
376+ t = target if isinstance (target , str ) else target .name
377+ hot = not traced or (node_name , t ) in traversed_edges # type: ignore[operator]
378+ dot .edge (
379+ node_name ,
380+ t ,
381+ label = " False" ,
382+ color = "#c62828" if hot else "#cccccc" ,
383+ fontcolor = "#c62828" if hot else "#cccccc" ,
384+ penwidth = "2" if hot and traced else "1" ,
385+ )
386+ elif isinstance (node , Arbiter ):
387+ for output in node .outputs ():
388+ hot = not traced or (node_name , output .name ) in traversed_edges # type: ignore[operator]
389+ dot .edge (
390+ node_name ,
391+ output .name ,
392+ color = "#555555" if hot else "#cccccc" ,
393+ penwidth = "2" if hot and traced else "1" ,
394+ )
395+ else :
396+ for target in node .routes :
397+ t = target if isinstance (target , str ) else target .name
398+ hot = not traced or (node_name , t ) in traversed_edges # type: ignore[operator]
399+ dot .edge (
400+ node_name ,
401+ t ,
402+ color = "#555555" if hot else "#cccccc" ,
403+ penwidth = "2" if hot and traced else "1" ,
404+ )
405+
406+ try :
407+ return Image (dot .pipe (format = "png" ))
408+ except graphviz .ExecutableNotFound :
409+ raise RuntimeError (
410+ "Graphviz system binaries not found. Install them with:\n "
411+ " macOS: brew install graphviz\n "
412+ " Ubuntu: apt-get install graphviz\n "
413+ " conda: conda install graphviz"
414+ ) from None
415+
416+ @requires_validate_and_build
417+ def visualize_run (self , ctx : EvalContext ):
418+ """Run the DAG on ctx and return a PNG with the executed path highlighted."""
419+ final_output , node_outputs , traversed_edges = self ._run_traced (ctx )
420+ return self .visualize (
421+ node_outputs = node_outputs ,
422+ traversed_edges = traversed_edges ,
423+ final_output = final_output ,
424+ )
258425
259426
260427class DAGAnnotator (Annotator ):
0 commit comments