@@ -1355,6 +1355,20 @@ def _mkv_(name, itype, irank):
13551355 )
13561356
13571357
1358+ def unknown_names_within_nodes (nodes : List [NodeProto ]) -> Set [str ]:
1359+ """Returns the list of unkonwn results from a list of nodes."""
1360+ not_known : Set [str ] = set ()
1361+ for node in nodes [::- 1 ]:
1362+ not_known -= {o for o in node .output if o }
1363+ not_known |= {i for i in node .input if i }
1364+ if node .op_type in {"Scan" , "If" , "Loop" }:
1365+ # there are hidden inputs
1366+ for att in node .attribute :
1367+ if att .type == onnx .AttributeProto .GRAPH :
1368+ not_known |= get_hidden_inputs (att .g )
1369+ return not_known
1370+
1371+
13581372def make_subfunction (
13591373 name : str ,
13601374 nodes : List [NodeProto ],
@@ -1374,21 +1388,12 @@ def make_subfunction(
13741388 :param domain: function domain
13751389 :return: model proto
13761390 """
1377- not_known : Set [str ] = set ()
1378- for node in nodes [::- 1 ]:
1379- not_known -= {o for o in node .output if o }
1380- not_known |= {i for i in node .input if i }
1381- if node .op_type in {"Scan" , "If" , "Loop" }:
1382- # there are hidden inputs
1383- for att in node .attribute :
1384- if att .type == onnx .AttributeProto .GRAPH :
1385- not_known |= get_hidden_inputs (att .g )
13861391
13871392 return oh .make_function (
13881393 domain ,
13891394 name ,
13901395 nodes = nodes ,
1391- inputs = sorted (not_known ),
1396+ inputs = sorted (unknown_names_within_nodes ( nodes ) ),
13921397 outputs = output_names ,
13931398 opset_imports = opset_imports ,
13941399 )
@@ -1775,8 +1780,6 @@ def check_for_non_recursivity(
17751780 needs an output of the function and is also required by the function:
17761781 it is probably missing from the initial set.
17771782
1778-
1779-
17801783 :param node_list: list of nodes
17811784 :param inputs: input names to consider
17821785 :param outputs: output names which cannot be involved in input names
0 commit comments