@@ -95,7 +95,7 @@ def _validate_and_build(self) -> None:
9595 for node_name , node in self ._nodes .items ():
9696 for target in node .all_routes ():
9797 if target not in self ._nodes and not isinstance (
98- target , self ._output_type
98+ target , self .output_type
9999 ):
100100 raise ValueError (
101101 f"Node { node_name } routes to unregistered node { target } or incompatible output."
@@ -132,9 +132,9 @@ def _validate_and_build(self) -> None:
132132 for terminal in terminal_nodes :
133133 node = self ._nodes [terminal ]
134134 if isinstance (node , Arbiter ):
135- if not issubclass (node .output_type , self ._output_type ):
135+ if not issubclass (node .output_type , self .output_type ):
136136 raise ValueError (
137- f"Terminal node { terminal } has output_type { node .output_type .__name__ } , which is not compatible with the DAG's output_type { self ._output_type .__name__ } ."
137+ f"Terminal node { terminal } has output_type { node .output_type .__name__ } , which is not compatible with the DAG's output_type { self .output_type .__name__ } ."
138138 )
139139
140140 # build predecessors
@@ -267,7 +267,7 @@ def total_costs(self) -> dict[str, float]:
267267 )
268268
269269 base_path = " -> " .join (path )
270- path_costs [f"{ base_path } -> Out ({ self ._output_type .__name__ } )" ] = total
270+ path_costs [f"{ base_path } -> Out ({ self .output_type .__name__ } )" ] = total
271271
272272 return path_costs
273273
@@ -282,6 +282,8 @@ def _visualize(
282282
283283 When node_outputs/traversed_edges/final_output are provided (via visualize_run),
284284 the hot path is highlighted and each node shows its output value.
285+
286+ NOTE: this helper method is vibe-coded and provided as-is.
285287 """
286288 import graphviz
287289 from IPython .display import Image
@@ -302,7 +304,11 @@ def _visualize(
302304 "style" : "filled,rounded,dashed" ,
303305 "fillcolor" : "#dcedc8" ,
304306 }
305- _DEFAULT_STYLE = {"shape" : "rectangle" , "style" : "filled" , "fillcolor" : "#eeeeee" }
307+ _DEFAULT_STYLE = {
308+ "shape" : "rectangle" ,
309+ "style" : "filled" ,
310+ "fillcolor" : "#eeeeee" ,
311+ }
306312 _DIM = {
307313 "style" : "filled" ,
308314 "fillcolor" : "#f0f0f0" ,
@@ -392,12 +398,14 @@ def _truncate(s: str, n: int = 24) -> str:
392398 attrs ["penwidth" ] = "2.5"
393399 else :
394400 attrs = dict (_DIM , shape = "rectangle" , style = "filled,rounded" )
395- bottom .node (out_name , repr (out_inst ), fontsize = _fontsize (repr (out_inst )), ** attrs )
401+ bottom .node (
402+ out_name , repr (out_inst ), fontsize = _fontsize (repr (out_inst )), ** attrs
403+ )
396404
397405 # synthetic output type node for Arbiters
398406 if has_arbiter :
399- output_node_id = f"__output_{ self ._output_type .__name__ } __"
400- output_label = f"{ self ._output_type .__name__ } (?)"
407+ output_node_id = f"__output_{ self .output_type .__name__ } __"
408+ output_label = f"{ self .output_type .__name__ } (?)"
401409 attrs = dict (_OUTPUT_TYPE_STYLE )
402410 if traced :
403411 if not final_from_direct and final_output is not None :
@@ -406,7 +414,9 @@ def _truncate(s: str, n: int = 24) -> str:
406414 output_label = repr (final_output )
407415 elif final_from_direct :
408416 attrs = dict (_DIM , shape = "rectangle" , style = "filled,rounded" )
409- bottom .node (output_node_id , output_label , fontsize = _fontsize (output_label ), ** attrs )
417+ bottom .node (
418+ output_node_id , output_label , fontsize = _fontsize (output_label ), ** attrs
419+ )
410420
411421 dot .subgraph (bottom )
412422
@@ -433,7 +443,11 @@ def _truncate(s: str, n: int = 24) -> str:
433443 attrs ["penwidth" ] = "2.5"
434444 else :
435445 label = node_name
436- _fill = 0.45 if isinstance (node , Gate ) else 0.65 if isinstance (node , Arbiter ) else 0.8
446+ _fill = (
447+ 0.45
448+ if isinstance (node , Gate )
449+ else 0.65 if isinstance (node , Arbiter ) else 0.8
450+ )
437451 dot .node (node_name , label , fontsize = _fontsize (label , fill = _fill ), ** attrs )
438452
439453 # edges from implicit input to root nodes
@@ -466,7 +480,7 @@ def _truncate(s: str, n: int = 24) -> str:
466480 penwidth = "2" if hot and traced else "1" ,
467481 )
468482 elif isinstance (node , Arbiter ):
469- output_node_id = f"__output_{ self ._output_type .__name__ } __"
483+ output_node_id = f"__output_{ self .output_type .__name__ } __"
470484 hot = not traced or node_name in (node_outputs or {})
471485 dot .edge (
472486 node_name ,
0 commit comments