Skip to content

Commit 0315031

Browse files
authored
investigation around input observer (#398)
* investigation around input observer * fix * more * disable two tests'
1 parent 37ab9b8 commit 0315031

4 files changed

Lines changed: 225 additions & 48 deletions

File tree

_unittests/ut_ci_models/test_ci_export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class TestCiExport(ExtTestCase):
1212
@hide_stdout()
13+
@requires_transformers("4.55")
1314
def test_main_qwen25_tiny_llm(self):
1415
main_qwen25(
1516
model_id="arnir0/Tiny-LLM",

_unittests/ut_investigate/test_input_observer.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from onnx_diagnostic.ext_test_case import ExtTestCase
55
from onnx_diagnostic.investigate.input_observer import (
66
InputObserver,
7-
infer_dynamic_dimensions,
7+
_infer_dynamic_dimensions,
88
)
99

1010

1111
class TestInputObserver(ExtTestCase):
1212
def test_infer_dynamic_dimensions(self):
13-
self.assertEqual([2], infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)]))
14-
self.assertEqual([0, 2], infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)]))
13+
self.assertEqual([2], _infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)]))
14+
self.assertEqual([0, 2], _infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)]))
1515

1616
def test_io_captured_args(self):
1717
class Model(torch.nn.Module):
@@ -42,6 +42,36 @@ def forward(self, x, y):
4242
self.assertIsInstance(args, tuple)
4343
self.assertEqual(2, len(args))
4444

45+
def test_io_captured_not_forward(self):
46+
class Model(torch.nn.Module):
47+
def notforward(self, w):
48+
return w.abs()
49+
50+
def forward(self, x, y):
51+
return x + self.notforward(y)
52+
53+
inputs = [
54+
(torch.randn((5, 6)), torch.randn((1, 6))),
55+
(torch.randn((7, 7)), torch.randn((1, 7))),
56+
(torch.randn((7, 8)), torch.randn((1, 8))),
57+
(torch.randn((7, 9)), torch.randn((1, 9))),
58+
]
59+
60+
model = Model()
61+
observer = InputObserver()
62+
with observer(model, method_name="notforward"):
63+
for args in inputs:
64+
model(*args)
65+
self.assertEqual(len(observer.info), 3)
66+
for i in range(3):
67+
self.assertEqual(len(observer.info.flat_outputs[i]), 1)
68+
69+
cst = torch.export.Dim.DYNAMIC
70+
self.assertEqual(({1: cst},), observer.infer_dynamic_shapes())
71+
args = observer.infer_arguments()
72+
self.assertIsInstance(args, tuple)
73+
self.assertEqual(1, len(args))
74+
4575
def test_io_captured_kwargs(self):
4676
class Model(torch.nn.Module):
4777
def forward(self, x, y):
@@ -516,6 +546,56 @@ def forward(self, x, custom=None):
516546
model(*args)
517547
self.assertEqual(expected, observer.infer_dynamic_shapes())
518548

549+
def test_io_captured_args_kwargs_dynamic_batch(self):
550+
class Model(torch.nn.Module):
551+
def forward(self, x, y, z=None, w=None):
552+
r = x + y
553+
if z is not None:
554+
r += z
555+
if w is not None:
556+
r += w
557+
return r
558+
559+
inputs = [
560+
(
561+
(torch.randn((5, 6)), torch.randn((1, 6))),
562+
dict(z=torch.randn((5, 6)), w=torch.randn((1, 6))),
563+
),
564+
(
565+
(torch.randn((5, 7)), torch.randn((1, 7))),
566+
dict(z=torch.randn((5, 7)), w=torch.randn((1, 7))),
567+
),
568+
(
569+
(torch.randn((5, 8)), torch.randn((1, 8))),
570+
dict(z=torch.randn((5, 8)), w=torch.randn((1, 8))),
571+
),
572+
(
573+
(torch.randn((5, 9)), torch.randn((1, 9))),
574+
dict(z=torch.randn((5, 9)), w=torch.randn((1, 9))),
575+
),
576+
]
577+
578+
model = Model()
579+
expected = [model(*args, **kwargs) for args, kwargs in inputs]
580+
observer = InputObserver()
581+
with observer(model):
582+
for args, kwargs in inputs:
583+
model(*args, **kwargs)
584+
self.assertEqual(len(observer.info), 3)
585+
for i in range(3):
586+
self.assertEqual(len(observer.info.flat_outputs[i]), 1)
587+
torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0])
588+
589+
cst = torch.export.Dim.DYNAMIC
590+
self.assertEqual(
591+
dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
592+
observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"}),
593+
)
594+
self.assertEqual(
595+
dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}),
596+
observer.infer_dynamic_shapes(add_batch_dimension_for={"x", "z"}),
597+
)
598+
519599

520600
if __name__ == "__main__":
521601
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def add_test_methods(cls):
8787

8888
if (
8989
not reason
90-
and name in {"plot_export_tiny_llm.py"}
91-
and not has_transformers("4.51")
90+
and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"}
91+
and not has_transformers("4.55")
9292
):
93-
reason = "transformers<4.51"
93+
reason = "transformers<4.55"
9494

9595
if (
9696
not reason

0 commit comments

Comments
 (0)