Skip to content

Commit 821fa76

Browse files
committed
fix partition
1 parent 0c3e47e commit 821fa76

2 files changed

Lines changed: 112 additions & 54 deletions

File tree

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -892,26 +892,31 @@ def test_make_model_with_local_functions_3(self):
892892
meta = node.metadata_props.add()
893893
meta.key = f"source[{i_node}]"
894894
meta.value = "LLL"
895+
self.assertRaise(
896+
lambda: make_model_with_local_functions(
897+
model,
898+
"^LLL$",
899+
metadata_key_prefix="source[",
900+
verbose=1,
901+
allow_extensions=False,
902+
),
903+
ValueError,
904+
)
895905
new_model = make_model_with_local_functions(
896906
model, "^LLL$", metadata_key_prefix="source[", verbose=1
897907
)
898908
check_model(new_model)
899909
self.assertEqual(len(new_model.functions), 1)
900910
p = pretty_onnx(new_model)
901-
self.assertIn("LLL0[local_function]", p)
902-
self.assertIn("LLL1[local_function]", p)
911+
self.assertIn("LLL[local_function]", p)
903912

904-
self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input)
905-
self.assertEqual(["xm1"], new_model.functions[0].output)
906-
self.assertEqual("LLL0", new_model.functions[0].name)
913+
self.assertEqual(
914+
["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input
915+
)
916+
self.assertEqual(["xm"], new_model.functions[0].output)
917+
self.assertEqual("LLL", new_model.functions[0].name)
907918
self.assertEqual("local_function", new_model.functions[0].domain)
908-
self.assertEqual(len(new_model.functions[0].node), 3)
909-
910-
self.assertEqual(["Y", "shape2"], new_model.functions[1].input)
911-
self.assertEqual(["xm2c"], new_model.functions[1].output)
912-
self.assertEqual("LLL1", new_model.functions[1].name)
913-
self.assertEqual("local_function", new_model.functions[1].domain)
914-
self.assertEqual(len(new_model.functions[1].node), 1)
919+
self.assertEqual(len(new_model.functions[0].node), 6)
915920

916921
check_model(new_model)
917922

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 95 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,9 +1773,10 @@ def _find_used_names(node_list, node_indices):
17731773
def check_for_non_recursivity(
17741774
node_indices: List[int],
17751775
node_list: List[Optional[NodeProto]],
1776-
inputs: Sequence[str],
1777-
outputs: Sequence[str],
1778-
):
1776+
inputs: Union[Set[str], Sequence[str]],
1777+
outputs: Union[Set[str], Sequence[str]],
1778+
exc: bool = True,
1779+
) -> List[int]:
17791780
"""
17801781
We need to check that any of this output is not required
17811782
by one input from the function itself, that would mean one node
@@ -1786,13 +1787,17 @@ def check_for_non_recursivity(
17861787
:param node_list: list of nodes
17871788
:param inputs: input names to consider
17881789
:param outputs: output names which cannot be involved in input names
1790+
:param exc: raise an exception as soon as possible it becomes impossible
1791+
:return: list of nodes to add to make the list of node consistence
1792+
with the list of inputs and outputs (they should be recomputed)
17891793
"""
1790-
orginal_set_inputs = set(inputs)
1791-
set_inputs = set(inputs)
1792-
original_set_outputs = set(outputs)
1794+
orginal_set_inputs = inputs if isinstance(inputs, set) else set(inputs)
1795+
set_inputs = orginal_set_inputs.copy()
1796+
original_set_outputs = outputs if isinstance(outputs, set) else set(outputs)
17931797
subset = set(node_indices)
17941798
before_inputs = set()
17951799
indexed_node = list(enumerate(node_list))
1800+
additional_nodes: List[int] = []
17961801
for ind, node in indexed_node[::-1]:
17971802
if not node:
17981803
continue
@@ -1822,15 +1827,18 @@ def check_for_non_recursivity(
18221827
if att.type == onnx.AttributeProto.GRAPH:
18231828
before_inputs |= get_hidden_inputs(att.g)
18241829
if original_set_outputs & before_inputs:
1825-
raise ValueError(
1826-
f"Results {original_set_outputs & before_inputs} "
1827-
f"are needed for inputs {inputs} "
1828-
f"but also requires {outputs} which is not allowed."
1829-
)
1830+
if exc:
1831+
raise ValueError(
1832+
f"Results {original_set_outputs & before_inputs} "
1833+
f"are needed for inputs {inputs} "
1834+
f"but also requires {outputs} which is not allowed."
1835+
)
1836+
additional_nodes.append(ind)
1837+
return additional_nodes
18301838

18311839

18321840
def _select_nodes_from_metadata_with_regex(
1833-
model: ModelProto, prefix: str, regex: str
1841+
model: ModelProto, prefix: Union[str, Tuple[str, ...]], regex: str
18341842
) -> Tuple[Dict[str, List[int]], Set[str]]:
18351843
reg = re.compile(regex)
18361844
unique_values = set()
@@ -1860,6 +1868,7 @@ def make_model_with_local_functions(
18601868
regex: str = ".*[.]layers[.][0-9]+[.]forward$",
18611869
domain: str = "local_function",
18621870
metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["),
1871+
allow_extensions: bool = True,
18631872
verbose: int = 0,
18641873
) -> ModelProto:
18651874
"""
@@ -1874,6 +1883,8 @@ def make_model_with_local_functions(
18741883
:param domain: function domain
18751884
:param metadata_keys: list of metadata keys to consider,
18761885
every value is split into multiple ones.
1886+
:param allow_extensions: allows the function to take nodes outside
1887+
a partition if there are not already inside another partition
18771888
:param verbose: verbosity
18781889
:return: model proto
18791890
"""
@@ -1900,40 +1911,82 @@ def make_model_with_local_functions(
19001911
print(" ...")
19011912
functions = []
19021913
new_nodes: List[Optional[NodeProto]] = list(model.graph.node)
1903-
for key, node_indices in unique.items():
1904-
function_name = key.strip().replace(".", "_")
1905-
if verbose:
1906-
print(
1907-
f"[make_model_with_local_functions] move {len(node_indices)} "
1908-
f"nodes in partition {function_name!r}"
1914+
processed = {}
1915+
unique_as_set = {k: set(v) for k, v in unique.items()}
1916+
while len(processed) < len(unique):
1917+
for key, node_indices in unique.items():
1918+
if key in processed:
1919+
# already processed
1920+
continue
1921+
function_name = key.strip().replace(".", "_")
1922+
if verbose:
1923+
print(
1924+
f"[make_model_with_local_functions] move {len(node_indices)} "
1925+
f"nodes in partition {key!r} (function={function_name!r})"
1926+
)
1927+
outputs = _find_used_names(new_nodes, node_indices)
1928+
# pyrefly: ignore[bad-assignment]
1929+
function_nodes: List[NodeProto] = [
1930+
new_nodes[i] for i in node_indices if new_nodes[i]
1931+
]
1932+
1933+
function_inputs = unknown_names_within_nodes(function_nodes)
1934+
additional_nodes = check_for_non_recursivity(
1935+
node_indices, new_nodes, function_inputs, outputs, exc=False
19091936
)
1910-
outputs = _find_used_names(new_nodes, node_indices)
1911-
function_nodes = [new_nodes[i] for i in node_indices if new_nodes[i]]
1912-
1913-
check_for_non_recursivity(
1914-
node_indices, model.graph.node, unknown_names_within_nodes(function_nodes), outputs
1915-
)
1937+
if additional_nodes:
1938+
if not allow_extensions:
1939+
raise ValueError(
1940+
f"Function for key={key!r} cannot be added because "
1941+
f"it must steal a node outside the partition, node ids "
1942+
f"{additional_nodes} are needed for inputs {function_inputs} "
1943+
f"but also requires {outputs} which is not allowed."
1944+
)
1945+
# Additional nodes are needed to make the function consistence.
1946+
# We check they are not in conflict with other partitions not
1947+
# yet processed.
1948+
set_add = set(additional_nodes)
1949+
for k, v in unique_as_set.items():
1950+
if v & set_add:
1951+
raise ValueError(
1952+
f"Function for key={key!r} cannot be added because "
1953+
f"it is conflict with other key {k!r} with node ids "
1954+
f"{set_add & v} are needed for inputs {function_inputs} "
1955+
f"but also requires {outputs} which is not allowed."
1956+
)
1957+
# If no exception, everything is fine, let's add the nodes.
1958+
node_indices.extend(additional_nodes)
1959+
node_indices[:] = sorted(node_indices)
1960+
# Inputs and outputs needed to be recomputed. Let's do that in another
1961+
# iteration.
1962+
if verbose:
1963+
print(
1964+
f"[make_model_with_local_functions] add {len(additional_nodes)} "
1965+
f"nodes in partition {key!r}"
1966+
)
1967+
continue
19161968

1917-
lf = make_subfunction(
1918-
function_name,
1919-
function_nodes,
1920-
model.opset_import,
1921-
outputs,
1922-
domain=domain,
1923-
)
1969+
lf = make_subfunction(
1970+
function_name,
1971+
function_nodes,
1972+
model.opset_import,
1973+
outputs,
1974+
domain=domain,
1975+
)
19241976

1925-
check_for_non_recursivity(node_indices, model.graph.node, lf.input, lf.output)
1977+
check_for_non_recursivity(node_indices, new_nodes, lf.input, lf.output)
19261978

1927-
if verbose:
1928-
print(
1929-
f"[make_model_with_local_functions] add function {function_name}"
1930-
f"({', '.join(lf.input)}) -> {', '.join(lf.input)}"
1931-
)
1932-
functions.append(lf)
1933-
maxi = max(node_indices)
1934-
for i in node_indices:
1935-
new_nodes[i] = None
1936-
new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain)
1979+
if verbose:
1980+
print(
1981+
f"[make_model_with_local_functions] add function {function_name}"
1982+
f"({', '.join(lf.input)}) -> {', '.join(lf.input)}"
1983+
)
1984+
functions.append(lf)
1985+
maxi = max(node_indices)
1986+
for i in node_indices:
1987+
new_nodes[i] = None
1988+
new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain)
1989+
processed[key] = lf
19371990

19381991
return oh.make_model(
19391992
oh.make_graph(

0 commit comments

Comments
 (0)