Skip to content

Commit 3741f5f

Browse files
committed
Use output type property.
1 parent 2a38c67 commit 3741f5f

File tree

1 file changed

+25
-11
lines changed
  • src/modelplane/evaluator

1 file changed

+25
-11
lines changed

src/modelplane/evaluator/dag.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _validate_and_build(self) -> None:
9595
for node_name, node in self._nodes.items():
9696
for target in node.all_routes():
9797
if target not in self._nodes and not isinstance(
98-
target, self._output_type
98+
target, self.output_type
9999
):
100100
raise ValueError(
101101
f"Node {node_name} routes to unregistered node {target} or incompatible output."
@@ -132,9 +132,9 @@ def _validate_and_build(self) -> None:
132132
for terminal in terminal_nodes:
133133
node = self._nodes[terminal]
134134
if isinstance(node, Arbiter):
135-
if not issubclass(node.output_type, self._output_type):
135+
if not issubclass(node.output_type, self.output_type):
136136
raise ValueError(
137-
f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self._output_type.__name__}."
137+
f"Terminal node {terminal} has output_type {node.output_type.__name__}, which is not compatible with the DAG's output_type {self.output_type.__name__}."
138138
)
139139

140140
# build predecessors
@@ -267,7 +267,7 @@ def total_costs(self) -> dict[str, float]:
267267
)
268268

269269
base_path = " -> ".join(path)
270-
path_costs[f"{base_path} -> Out ({self._output_type.__name__})"] = total
270+
path_costs[f"{base_path} -> Out ({self.output_type.__name__})"] = total
271271

272272
return path_costs
273273

@@ -282,6 +282,8 @@ def _visualize(
282282
283283
When node_outputs/traversed_edges/final_output are provided (via visualize_run),
284284
the hot path is highlighted and each node shows its output value.
285+
286+
NOTE: this helper method is vibe-coded and provided as-is.
285287
"""
286288
import graphviz
287289
from IPython.display import Image
@@ -302,7 +304,11 @@ def _visualize(
302304
"style": "filled,rounded,dashed",
303305
"fillcolor": "#dcedc8",
304306
}
305-
_DEFAULT_STYLE = {"shape": "rectangle", "style": "filled", "fillcolor": "#eeeeee"}
307+
_DEFAULT_STYLE = {
308+
"shape": "rectangle",
309+
"style": "filled",
310+
"fillcolor": "#eeeeee",
311+
}
306312
_DIM = {
307313
"style": "filled",
308314
"fillcolor": "#f0f0f0",
@@ -392,12 +398,14 @@ def _truncate(s: str, n: int = 24) -> str:
392398
attrs["penwidth"] = "2.5"
393399
else:
394400
attrs = dict(_DIM, shape="rectangle", style="filled,rounded")
395-
bottom.node(out_name, repr(out_inst), fontsize=_fontsize(repr(out_inst)), **attrs)
401+
bottom.node(
402+
out_name, repr(out_inst), fontsize=_fontsize(repr(out_inst)), **attrs
403+
)
396404

397405
# synthetic output type node for Arbiters
398406
if has_arbiter:
399-
output_node_id = f"__output_{self._output_type.__name__}__"
400-
output_label = f"{self._output_type.__name__} (?)"
407+
output_node_id = f"__output_{self.output_type.__name__}__"
408+
output_label = f"{self.output_type.__name__} (?)"
401409
attrs = dict(_OUTPUT_TYPE_STYLE)
402410
if traced:
403411
if not final_from_direct and final_output is not None:
@@ -406,7 +414,9 @@ def _truncate(s: str, n: int = 24) -> str:
406414
output_label = repr(final_output)
407415
elif final_from_direct:
408416
attrs = dict(_DIM, shape="rectangle", style="filled,rounded")
409-
bottom.node(output_node_id, output_label, fontsize=_fontsize(output_label), **attrs)
417+
bottom.node(
418+
output_node_id, output_label, fontsize=_fontsize(output_label), **attrs
419+
)
410420

411421
dot.subgraph(bottom)
412422

@@ -433,7 +443,11 @@ def _truncate(s: str, n: int = 24) -> str:
433443
attrs["penwidth"] = "2.5"
434444
else:
435445
label = node_name
436-
_fill = 0.45 if isinstance(node, Gate) else 0.65 if isinstance(node, Arbiter) else 0.8
446+
_fill = (
447+
0.45
448+
if isinstance(node, Gate)
449+
else 0.65 if isinstance(node, Arbiter) else 0.8
450+
)
437451
dot.node(node_name, label, fontsize=_fontsize(label, fill=_fill), **attrs)
438452

439453
# edges from implicit input to root nodes
@@ -466,7 +480,7 @@ def _truncate(s: str, n: int = 24) -> str:
466480
penwidth="2" if hot and traced else "1",
467481
)
468482
elif isinstance(node, Arbiter):
469-
output_node_id = f"__output_{self._output_type.__name__}__"
483+
output_node_id = f"__output_{self.output_type.__name__}__"
470484
hot = not traced or node_name in (node_outputs or {})
471485
dot.edge(
472486
node_name,

0 commit comments

Comments
 (0)