Skip to content

Commit f1bc2d8

Browse files
authored
add function make_model_with_local_functions to partition a model into local functions (#394)
* add function to partition a model into local functions * changelogs * support prefixes * documentation and verbosity * speel * refly * fix doc
1 parent 83d3f41 commit f1bc2d8

13 files changed

Lines changed: 553 additions & 19 deletions

File tree

.github/workflows/ci.yml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,31 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.6', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.56.2', '4.57.6', 'main']
2121
torch: ['2.9', 'main']
2222
exclude:
2323
- python: '3.10' # 3.10
2424
torch: 'main'
2525
- python: '3.10'
2626
torch: '2.9'
27-
- python: '3.10'
28-
transformers: 'main'
29-
- python: '3.10'
30-
transformers: '4.52.4'
3127
- python: '3.10'
3228
transformers: '4.55.4'
3329
- python: '3.10'
3430
transformers: '4.56.2'
3531
- python: '3.10'
3632
transformers: '4.57.6'
33+
- python: '3.10'
34+
transformers: 'main'
3735
- python: '3.11' # 3.11
3836
torch: 'main'
39-
- python: '3.11'
40-
transformers: 'main'
4137
- python: '3.11'
4238
transformers: '4.55.4'
4339
- python: '3.11'
4440
transformers: '4.56.2'
4541
- python: '3.11'
4642
transformers: '4.57.6'
43+
- python: '3.11'
44+
transformers: 'main'
4745
- python: '3.13' # 3.11
4846
torch: '2.9'
4947
- python: '3.13'

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.8.11
55
++++++
66

7+
* :pr:`394`: add function make_model_with_local_functions to partition a model into local functions
8+
79
0.8.10
810
++++++
911

_doc/api/api.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_doc/api/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ API of onnx_diagnostic
2020
:maxdepth: 1
2121
:caption: modules
2222

23-
api
23+
typing
2424
ext_test_case
2525

2626
.. automodule:: onnx_diagnostic

_doc/api/typing.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
onnx_diagnostic.typing
3+
======================
4+
5+
.. automodule:: onnx_diagnostic.typing
6+
:members:

_doc/cmds/_img_partition.png

33.6 KB
Loading

_doc/cmds/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ Command Lines
1111
compare
1212
config
1313
optimize
14+
partition
1415
sbs
1516
validate

_doc/cmds/partition.rst

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
-m onnx_diagnostic partition ... move layer nodes in local functions
2+
====================================================================
3+
4+
The command line leverages the metadata added by the exporter.
5+
Every node is tagged with information indicating which part of the model
6+
it comes from. In particular the eky `namespace`:
7+
8+
::
9+
10+
transformers.models.llama.modeling_llama.LlamaForCausalLM/model:
11+
transformers.models.llama.modeling_llama.LlamaModel/model.layers.0:
12+
transformers.models.llama.modeling_llama.LlamaDecoderLayer/model.layers.0.self_attn:
13+
transformers.models.llama.modeling_llama.LlamaAttention/unsqueeze_15:
14+
aten.unsqueeze.default
15+
16+
Description
17+
+++++++++++
18+
19+
See :func:`onnx_diagnostic.helpers.onnx_helper.make_model_with_local_functions`.
20+
21+
.. runpython::
22+
23+
from onnx_diagnostic._command_lines_parser import get_parser_partition
24+
25+
get_parser_partition().print_help()
26+
27+
Example
28+
+++++++
29+
30+
.. code-block:: bash
31+
32+
python -m onnx_diagnostic partition arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx partition.onnx -r ".*[.]layers[.][0-9]+$" -v 1
33+
34+
This produces the following output:
35+
36+
::
37+
38+
-- load 'arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx'
39+
-- partition
40+
[make_model_with_local_functions] matched 1 partitions
41+
[make_model_with_local_functions] move 89 nodes in partition 'transformers_models_llama_modeling_llama_LlamaModel/model_layers_0'
42+
-- save into 'partition.onnx'
43+
-- done
44+
45+
The partitioned model includes the following node:
46+
47+
.. image:: _img_partition.png

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
shadowing_names,
2929
onnx_dtype_name,
3030
extract_subset_of_nodes,
31+
make_subfunction,
3132
make_submodel,
33+
make_model_with_local_functions,
3234
select_model_inputs_outputs,
3335
_enumerate_model_node_outputs,
36+
pretty_onnx,
3437
)
3538

3639
TFLOAT = TensorProto.FLOAT
@@ -537,6 +540,46 @@ def _type_rank_fn(name):
537540
check_model(new_model)
538541
self.check_ort(new_model)
539542

543+
def test_make_subfunction(self):
544+
model = oh.make_model(
545+
oh.make_graph(
546+
[
547+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
548+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
549+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
550+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
551+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
552+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
553+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
554+
],
555+
"dummy",
556+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
557+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
558+
[
559+
onh.from_array(
560+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
561+
),
562+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
563+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
564+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
565+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
566+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
567+
],
568+
),
569+
opset_imports=[oh.make_opsetid("", 18)],
570+
ir_version=9,
571+
)
572+
new_function = make_subfunction(
573+
"localf",
574+
model.graph.node[:4],
575+
opset_imports=model.opset_import,
576+
output_names=["xm1", "xm2c"],
577+
)
578+
self.assertIsInstance(new_function, FunctionProto)
579+
self.assertEqual(len(new_function.node), 4)
580+
self.assertEqual(new_function.output, ["xm1", "xm2c"])
581+
self.assertEqual(new_function.input, ["X", "Y", "shape1", "shape2", "un", "zero"])
582+
540583
def test_extract_subset_of_nodes_bigger(self):
541584
model = onnx.load(
542585
os.path.join(
@@ -670,6 +713,153 @@ def enumerate_model_tensors(model):
670713
got = sess.run(None, {"X": x})[0]
671714
self.assertEqual((x**2 + y).tolist(), got.tolist())
672715

716+
def test_make_model_with_local_functions(self):
717+
model = oh.make_model(
718+
oh.make_graph(
719+
[
720+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
721+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
722+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
723+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
724+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
725+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
726+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
727+
],
728+
"dummy",
729+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
730+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
731+
[
732+
onh.from_array(
733+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
734+
),
735+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
736+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
737+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
738+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
739+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
740+
],
741+
),
742+
opset_imports=[oh.make_opsetid("", 18)],
743+
ir_version=9,
744+
)
745+
for i_node in [0, 1, 2, 3]:
746+
node = model.graph.node[i_node]
747+
meta = node.metadata_props.add()
748+
meta.key = "namespace"
749+
meta.value = "LLL"
750+
new_model = make_model_with_local_functions(model, "^LLL$")
751+
check_model(model)
752+
self.assertEqual(len(new_model.functions), 1)
753+
self.assertEqual(
754+
["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input
755+
)
756+
self.assertEqual(["xm1", "xm2c"], new_model.functions[0].output)
757+
self.assertEqual("LLL", new_model.functions[0].name)
758+
self.assertEqual("local_function", new_model.functions[0].domain)
759+
self.assertIn("LLL[local_function]", pretty_onnx(new_model))
760+
check_model(new_model)
761+
762+
def test_make_model_with_local_functions_bug(self):
763+
model = oh.make_model(
764+
oh.make_graph(
765+
[
766+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
767+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
768+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
769+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
770+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
771+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
772+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
773+
],
774+
"dummy",
775+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
776+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
777+
[
778+
onh.from_array(
779+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
780+
),
781+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
782+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
783+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
784+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
785+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
786+
],
787+
),
788+
opset_imports=[oh.make_opsetid("", 18)],
789+
ir_version=9,
790+
)
791+
for i_node in [0, 2, 3, 4]:
792+
node = model.graph.node[i_node]
793+
meta = node.metadata_props.add()
794+
meta.key = "namespace"
795+
meta.value = "LLL"
796+
self.assertRaise(
797+
lambda: make_model_with_local_functions(model, "^LLL$"),
798+
ValueError,
799+
"Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', "
800+
"'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] "
801+
"which is not allowed.",
802+
)
803+
check_model(model)
804+
805+
@hide_stdout()
806+
def test_make_model_with_local_functions_2(self):
807+
model = oh.make_model(
808+
oh.make_graph(
809+
[
810+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
811+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
812+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
813+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
814+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
815+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
816+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
817+
],
818+
"dummy",
819+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
820+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
821+
[
822+
onh.from_array(
823+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
824+
),
825+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
826+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
827+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
828+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
829+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
830+
],
831+
),
832+
opset_imports=[oh.make_opsetid("", 18)],
833+
ir_version=9,
834+
)
835+
for i_node in [0, 1, 2, 3]:
836+
node = model.graph.node[i_node]
837+
meta = node.metadata_props.add()
838+
meta.key = f"source[{i_node}]"
839+
meta.value = f"LLL{i_node//3}"
840+
new_model = make_model_with_local_functions(
841+
model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1
842+
)
843+
check_model(model)
844+
self.assertEqual(len(new_model.functions), 2)
845+
p = pretty_onnx(new_model)
846+
self.assertIn("LLL0[local_function]", p)
847+
self.assertIn("LLL1[local_function]", p)
848+
849+
self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input)
850+
self.assertEqual(["xm1"], new_model.functions[0].output)
851+
self.assertEqual("LLL0", new_model.functions[0].name)
852+
self.assertEqual("local_function", new_model.functions[0].domain)
853+
self.assertEqual(len(new_model.functions[0].node), 3)
854+
855+
self.assertEqual(["Y", "shape2"], new_model.functions[1].input)
856+
self.assertEqual(["xm2c"], new_model.functions[1].output)
857+
self.assertEqual("LLL1", new_model.functions[1].name)
858+
self.assertEqual("local_function", new_model.functions[1].domain)
859+
self.assertEqual(len(new_model.functions[1].node), 1)
860+
861+
check_model(new_model)
862+
673863

674864
if __name__ == "__main__":
675865
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_parser_find,
1212
get_parser_lighten,
1313
get_parser_optimize,
14+
get_parser_partition,
1415
get_parser_print,
1516
get_parser_sbs,
1617
get_parser_stats,
@@ -186,6 +187,13 @@ def test_parser_optimize(self):
186187
text = st.getvalue()
187188
self.assertIn("default", text)
188189

190+
def test_parser_partition(self):
191+
st = StringIO()
192+
with redirect_stdout(st):
193+
get_parser_partition().print_help()
194+
text = st.getvalue()
195+
self.assertIn("regex", text)
196+
189197

190198
if __name__ == "__main__":
191199
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)