@@ -202,7 +202,7 @@ def default_output(self):
202202 return self .outputs [do ]
203203
204204 def __str__ (self ):
205- return op_as_string (self .inputs , self )
205+ return node_as_string (self .inputs , self )
206206
207207 def __repr__ (self ):
208208 return str (self )
@@ -1409,8 +1409,11 @@ def compute_deps(obj):
14091409default_leaf_formatter = str
14101410
14111411
1412- def default_node_formatter (op , argstrings ):
1413- return f"{ op .op } ({ ', ' .join (argstrings )} )"
1412+ def default_node_formatter (node , input_strs , output_strs = None ):
1413+ if output_strs :
1414+ return f"{ ', ' .join (output_strs )} <- { node .op } ({ ', ' .join (input_strs )} )"
1415+ else :
1416+ return f"{ node .op } ({ ', ' .join (input_strs )} )"
14141417
14151418
14161419def io_connection_pattern (inputs , outputs ):
@@ -1479,12 +1482,16 @@ def io_connection_pattern(inputs, outputs):
14791482 return global_connection_pattern
14801483
14811484
1482- def op_as_string (
1483- i , op , leaf_formatter = default_leaf_formatter , node_formatter = default_node_formatter
1485+ def node_as_string (
1486+ inputs ,
1487+ node ,
1488+ leaf_formatter = default_leaf_formatter ,
1489+ node_formatter = default_node_formatter ,
14841490):
14851491 """Return a function that returns a string representation of the subgraph between `i` and :attr:`op.inputs`"""
1486- strs = as_string (i , op .inputs , leaf_formatter , node_formatter )
1487- return node_formatter (op , strs )
1492+ in_strs = as_string (inputs , node .inputs , leaf_formatter , node_formatter )
1493+ out_strs = as_string (node .outputs , node .outputs , leaf_formatter , node_formatter )
1494+ return node_formatter (node , in_strs , out_strs )
14881495
14891496
14901497def as_string (
0 commit comments