|
4 | 4 | from onnx_diagnostic.ext_test_case import ExtTestCase |
5 | 5 | from onnx_diagnostic.investigate.input_observer import ( |
6 | 6 | InputObserver, |
7 | | - infer_dynamic_dimensions, |
| 7 | + _infer_dynamic_dimensions, |
8 | 8 | ) |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class TestInputObserver(ExtTestCase): |
12 | 12 | def test_infer_dynamic_dimensions(self): |
13 | | - self.assertEqual([2], infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)])) |
14 | | - self.assertEqual([0, 2], infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)])) |
| 13 | + self.assertEqual([2], _infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)])) |
| 14 | + self.assertEqual([0, 2], _infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)])) |
15 | 15 |
|
16 | 16 | def test_io_captured_args(self): |
17 | 17 | class Model(torch.nn.Module): |
@@ -42,6 +42,36 @@ def forward(self, x, y): |
42 | 42 | self.assertIsInstance(args, tuple) |
43 | 43 | self.assertEqual(2, len(args)) |
44 | 44 |
|
| 45 | + def test_io_captured_not_forward(self): |
| 46 | + class Model(torch.nn.Module): |
| 47 | + def notforward(self, w): |
| 48 | + return w.abs() |
| 49 | + |
| 50 | + def forward(self, x, y): |
| 51 | + return x + self.notforward(y) |
| 52 | + |
| 53 | + inputs = [ |
| 54 | + (torch.randn((5, 6)), torch.randn((1, 6))), |
| 55 | + (torch.randn((7, 7)), torch.randn((1, 7))), |
| 56 | + (torch.randn((7, 8)), torch.randn((1, 8))), |
| 57 | + (torch.randn((7, 9)), torch.randn((1, 9))), |
| 58 | + ] |
| 59 | + |
| 60 | + model = Model() |
| 61 | + observer = InputObserver() |
| 62 | + with observer(model, method_name="notforward"): |
| 63 | + for args in inputs: |
| 64 | + model(*args) |
| 65 | + self.assertEqual(len(observer.info), 3) |
| 66 | + for i in range(3): |
| 67 | + self.assertEqual(len(observer.info.flat_outputs[i]), 1) |
| 68 | + |
| 69 | + cst = torch.export.Dim.DYNAMIC |
| 70 | + self.assertEqual(({1: cst},), observer.infer_dynamic_shapes()) |
| 71 | + args = observer.infer_arguments() |
| 72 | + self.assertIsInstance(args, tuple) |
| 73 | + self.assertEqual(1, len(args)) |
| 74 | + |
45 | 75 | def test_io_captured_kwargs(self): |
46 | 76 | class Model(torch.nn.Module): |
47 | 77 | def forward(self, x, y): |
@@ -516,6 +546,56 @@ def forward(self, x, custom=None): |
516 | 546 | model(*args) |
517 | 547 | self.assertEqual(expected, observer.infer_dynamic_shapes()) |
518 | 548 |
|
| 549 | + def test_io_captured_args_kwargs_dynamic_batch(self): |
| 550 | + class Model(torch.nn.Module): |
| 551 | + def forward(self, x, y, z=None, w=None): |
| 552 | + r = x + y |
| 553 | + if z is not None: |
| 554 | + r += z |
| 555 | + if w is not None: |
| 556 | + r += w |
| 557 | + return r |
| 558 | + |
| 559 | + inputs = [ |
| 560 | + ( |
| 561 | + (torch.randn((5, 6)), torch.randn((1, 6))), |
| 562 | + dict(z=torch.randn((5, 6)), w=torch.randn((1, 6))), |
| 563 | + ), |
| 564 | + ( |
| 565 | + (torch.randn((5, 7)), torch.randn((1, 7))), |
| 566 | + dict(z=torch.randn((5, 7)), w=torch.randn((1, 7))), |
| 567 | + ), |
| 568 | + ( |
| 569 | + (torch.randn((5, 8)), torch.randn((1, 8))), |
| 570 | + dict(z=torch.randn((5, 8)), w=torch.randn((1, 8))), |
| 571 | + ), |
| 572 | + ( |
| 573 | + (torch.randn((5, 9)), torch.randn((1, 9))), |
| 574 | + dict(z=torch.randn((5, 9)), w=torch.randn((1, 9))), |
| 575 | + ), |
| 576 | + ] |
| 577 | + |
| 578 | + model = Model() |
| 579 | + expected = [model(*args, **kwargs) for args, kwargs in inputs] |
| 580 | + observer = InputObserver() |
| 581 | + with observer(model): |
| 582 | + for args, kwargs in inputs: |
| 583 | + model(*args, **kwargs) |
| 584 | + self.assertEqual(len(observer.info), 3) |
| 585 | + for i in range(3): |
| 586 | + self.assertEqual(len(observer.info.flat_outputs[i]), 1) |
| 587 | + torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0]) |
| 588 | + |
| 589 | + cst = torch.export.Dim.DYNAMIC |
| 590 | + self.assertEqual( |
| 591 | + dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}), |
| 592 | + observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"}), |
| 593 | + ) |
| 594 | + self.assertEqual( |
| 595 | + dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}), |
| 596 | + observer.infer_dynamic_shapes(add_batch_dimension_for={"x", "z"}), |
| 597 | + ) |
| 598 | + |
519 | 599 |
|
520 | 600 | if __name__ == "__main__": |
521 | 601 | unittest.main(verbosity=2) |
0 commit comments