@@ -1773,9 +1773,10 @@ def _find_used_names(node_list, node_indices):
17731773def check_for_non_recursivity (
17741774 node_indices : List [int ],
17751775 node_list : List [Optional [NodeProto ]],
1776- inputs : Sequence [str ],
1777- outputs : Sequence [str ],
1778- ):
1776+ inputs : Union [Set [str ], Sequence [str ]],
1777+ outputs : Union [Set [str ], Sequence [str ]],
1778+ exc : bool = True ,
1779+ ) -> List [int ]:
17791780 """
17801781 We need to check that any of this output is not required
17811782 by one input from the function itself, that would mean one node
@@ -1786,13 +1787,17 @@ def check_for_non_recursivity(
17861787 :param node_list: list of nodes
17871788 :param inputs: input names to consider
17881789 :param outputs: output names which cannot be involved in input names
1790+ :param exc: raise an exception as soon as possible it becomes impossible
1791+ :return: list of nodes to add to make the list of node consistence
1792+ with the list of inputs and outputs (they should be recomputed)
17891793 """
1790- orginal_set_inputs = set (inputs )
1791- set_inputs = set ( inputs )
1792- original_set_outputs = set (outputs )
1794+ orginal_set_inputs = inputs if isinstance ( inputs , set ) else set (inputs )
1795+ set_inputs = orginal_set_inputs . copy ( )
1796+ original_set_outputs = outputs if isinstance ( outputs , set ) else set (outputs )
17931797 subset = set (node_indices )
17941798 before_inputs = set ()
17951799 indexed_node = list (enumerate (node_list ))
1800+ additional_nodes : List [int ] = []
17961801 for ind , node in indexed_node [::- 1 ]:
17971802 if not node :
17981803 continue
@@ -1822,15 +1827,18 @@ def check_for_non_recursivity(
18221827 if att .type == onnx .AttributeProto .GRAPH :
18231828 before_inputs |= get_hidden_inputs (att .g )
18241829 if original_set_outputs & before_inputs :
1825- raise ValueError (
1826- f"Results { original_set_outputs & before_inputs } "
1827- f"are needed for inputs { inputs } "
1828- f"but also requires { outputs } which is not allowed."
1829- )
1830+ if exc :
1831+ raise ValueError (
1832+ f"Results { original_set_outputs & before_inputs } "
1833+ f"are needed for inputs { inputs } "
1834+ f"but also requires { outputs } which is not allowed."
1835+ )
1836+ additional_nodes .append (ind )
1837+ return additional_nodes
18301838
18311839
18321840def _select_nodes_from_metadata_with_regex (
1833- model : ModelProto , prefix : str , regex : str
1841+ model : ModelProto , prefix : Union [ str , Tuple [ str , ...]] , regex : str
18341842) -> Tuple [Dict [str , List [int ]], Set [str ]]:
18351843 reg = re .compile (regex )
18361844 unique_values = set ()
@@ -1860,6 +1868,7 @@ def make_model_with_local_functions(
18601868 regex : str = ".*[.]layers[.][0-9]+[.]forward$" ,
18611869 domain : str = "local_function" ,
18621870 metadata_key_prefix : Union [str , Tuple [str , ...]] = ("namespace" , "source[" ),
1871+ allow_extensions : bool = True ,
18631872 verbose : int = 0 ,
18641873) -> ModelProto :
18651874 """
@@ -1874,6 +1883,8 @@ def make_model_with_local_functions(
18741883 :param domain: function domain
18751884 :param metadata_keys: list of metadata keys to consider,
18761885 every value is split into multiple ones.
1886+ :param allow_extensions: allows the function to take nodes outside
1887+ a partition if there are not already inside another partition
18771888 :param verbose: verbosity
18781889 :return: model proto
18791890 """
@@ -1900,40 +1911,82 @@ def make_model_with_local_functions(
19001911 print (" ..." )
19011912 functions = []
19021913 new_nodes : List [Optional [NodeProto ]] = list (model .graph .node )
1903- for key , node_indices in unique .items ():
1904- function_name = key .strip ().replace ("." , "_" )
1905- if verbose :
1906- print (
1907- f"[make_model_with_local_functions] move { len (node_indices )} "
1908- f"nodes in partition { function_name !r} "
1914+ processed = {}
1915+ unique_as_set = {k : set (v ) for k , v in unique .items ()}
1916+ while len (processed ) < len (unique ):
1917+ for key , node_indices in unique .items ():
1918+ if key in processed :
1919+ # already processed
1920+ continue
1921+ function_name = key .strip ().replace ("." , "_" )
1922+ if verbose :
1923+ print (
1924+ f"[make_model_with_local_functions] move { len (node_indices )} "
1925+ f"nodes in partition { key !r} (function={ function_name !r} )"
1926+ )
1927+ outputs = _find_used_names (new_nodes , node_indices )
1928+ # pyrefly: ignore[bad-assignment]
1929+ function_nodes : List [NodeProto ] = [
1930+ new_nodes [i ] for i in node_indices if new_nodes [i ]
1931+ ]
1932+
1933+ function_inputs = unknown_names_within_nodes (function_nodes )
1934+ additional_nodes = check_for_non_recursivity (
1935+ node_indices , new_nodes , function_inputs , outputs , exc = False
19091936 )
1910- outputs = _find_used_names (new_nodes , node_indices )
1911- function_nodes = [new_nodes [i ] for i in node_indices if new_nodes [i ]]
1912-
1913- check_for_non_recursivity (
1914- node_indices , model .graph .node , unknown_names_within_nodes (function_nodes ), outputs
1915- )
1937+ if additional_nodes :
1938+ if not allow_extensions :
1939+ raise ValueError (
1940+ f"Function for key={ key !r} cannot be added because "
1941+ f"it must steal a node outside the partition, node ids "
1942+ f"{ additional_nodes } are needed for inputs { function_inputs } "
1943+ f"but also requires { outputs } which is not allowed."
1944+ )
1945+ # Additional nodes are needed to make the function consistence.
1946+ # We check they are not in conflict with other partitions not
1947+ # yet processed.
1948+ set_add = set (additional_nodes )
1949+ for k , v in unique_as_set .items ():
1950+ if v & set_add :
1951+ raise ValueError (
1952+ f"Function for key={ key !r} cannot be added because "
1953+ f"it is conflict with other key { k !r} with node ids "
1954+ f"{ set_add & v } are needed for inputs { function_inputs } "
1955+ f"but also requires { outputs } which is not allowed."
1956+ )
1957+ # If no exception, everything is fine, let's add the nodes.
1958+ node_indices .extend (additional_nodes )
1959+ node_indices [:] = sorted (node_indices )
1960+ # Inputs and outputs needed to be recomputed. Let's do that in another
1961+ # iteration.
1962+ if verbose :
1963+ print (
1964+ f"[make_model_with_local_functions] add { len (additional_nodes )} "
1965+ f"nodes in partition { key !r} "
1966+ )
1967+ continue
19161968
1917- lf = make_subfunction (
1918- function_name ,
1919- function_nodes ,
1920- model .opset_import ,
1921- outputs ,
1922- domain = domain ,
1923- )
1969+ lf = make_subfunction (
1970+ function_name ,
1971+ function_nodes ,
1972+ model .opset_import ,
1973+ outputs ,
1974+ domain = domain ,
1975+ )
19241976
1925- check_for_non_recursivity (node_indices , model . graph . node , lf .input , lf .output )
1977+ check_for_non_recursivity (node_indices , new_nodes , lf .input , lf .output )
19261978
1927- if verbose :
1928- print (
1929- f"[make_model_with_local_functions] add function { function_name } "
1930- f"({ ', ' .join (lf .input )} ) -> { ', ' .join (lf .input )} "
1931- )
1932- functions .append (lf )
1933- maxi = max (node_indices )
1934- for i in node_indices :
1935- new_nodes [i ] = None
1936- new_nodes [maxi ] = oh .make_node (lf .name , lf .input , lf .output , domain = lf .domain )
1979+ if verbose :
1980+ print (
1981+ f"[make_model_with_local_functions] add function { function_name } "
1982+ f"({ ', ' .join (lf .input )} ) -> { ', ' .join (lf .input )} "
1983+ )
1984+ functions .append (lf )
1985+ maxi = max (node_indices )
1986+ for i in node_indices :
1987+ new_nodes [i ] = None
1988+ new_nodes [maxi ] = oh .make_node (lf .name , lf .input , lf .output , domain = lf .domain )
1989+ processed [key ] = lf
19371990
19381991 return oh .make_model (
19391992 oh .make_graph (
0 commit comments