|
23 | 23 | from executorch.backends.cadence.aot.reorder_ops import ( |
24 | 24 | AdvanceQuantizeOpAboveDefChainPass, |
25 | 25 | AdvanceQuantizeOpAboveDefInBranchPass, |
| 26 | + MoveSliceBeforePermutePass, |
26 | 27 | PostponeDequantizeOpBelowUseChainPass, |
27 | 28 | PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, |
28 | 29 | SinkOpsCloserToUsePass, |
@@ -633,3 +634,130 @@ def test_permute_view_chains_neg(self) -> None: |
633 | 634 | self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy) |
634 | 635 | self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy) |
635 | 636 | self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) |
| 637 | + |
| 638 | + |
| 639 | +class TestMoveSliceBeforePermutePass(unittest.TestCase): |
| 640 | + def test_basic_move(self) -> None: |
| 641 | + """permute → slice becomes slice → permute.""" |
| 642 | + builder = GraphBuilder() |
| 643 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) |
| 644 | + permuted = builder.call_operator( |
| 645 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 646 | + ) |
| 647 | + sliced = builder.call_operator( |
| 648 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 649 | + args=(permuted, 1, 0, 2, 1), |
| 650 | + ) |
| 651 | + builder.output([sliced]) |
| 652 | + original = builder.get_graph_module() |
| 653 | + |
| 654 | + result = transform_and_check_numerics( |
| 655 | + original, |
| 656 | + (torch.randn(2, 3, 4, 5),), |
| 657 | + MoveSliceBeforePermutePass(), |
| 658 | + ) |
| 659 | + self.assertTrue(result.modified) |
| 660 | + |
| 661 | + nodes = get_compute_nodes_in_gm(result.graph_module) |
| 662 | + self.assertEqual(len(nodes), 2) |
| 663 | + self.assertEqual(nodes[0], exir_ops.edge.aten.slice_copy) |
| 664 | + self.assertEqual(nodes[1], exir_ops.edge.aten.permute_copy) |
| 665 | + |
| 666 | + def test_multi_user_permute_no_change(self) -> None: |
| 667 | + """Permute with multiple users → no change (only single-user supported).""" |
| 668 | + builder = GraphBuilder() |
| 669 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) |
| 670 | + permuted = builder.call_operator( |
| 671 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 672 | + ) |
| 673 | + slice1 = builder.call_operator( |
| 674 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 675 | + args=(permuted, 1, 0, 2, 1), |
| 676 | + ) |
| 677 | + slice2 = builder.call_operator( |
| 678 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 679 | + args=(permuted, 2, 1, 3, 1), |
| 680 | + ) |
| 681 | + builder.output([slice1, slice2]) |
| 682 | + original = builder.get_graph_module() |
| 683 | + |
| 684 | + result = cast(PassResult, MoveSliceBeforePermutePass()(original)) |
| 685 | + self.assertFalse(result.modified) |
| 686 | + |
| 687 | + def test_no_slice_users_no_change(self) -> None: |
| 688 | + """Permute with no slice users → no change.""" |
| 689 | + builder = GraphBuilder() |
| 690 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) |
| 691 | + permuted = builder.call_operator( |
| 692 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 693 | + ) |
| 694 | + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(permuted,)) |
| 695 | + builder.output([neg]) |
| 696 | + original = builder.get_graph_module() |
| 697 | + |
| 698 | + result = cast(PassResult, MoveSliceBeforePermutePass()(original)) |
| 699 | + self.assertFalse(result.modified) |
| 700 | + |
| 701 | + def test_dim0_slice_large_reduction_moved(self) -> None: |
| 702 | + """Dim-0 slice removing >50% of data → profitable, moved.""" |
| 703 | + builder = GraphBuilder() |
| 704 | + x = builder.placeholder("x", torch.randn(10, 3, 4, 5)) |
| 705 | + permuted = builder.call_operator( |
| 706 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 707 | + ) |
| 708 | + sliced = builder.call_operator( |
| 709 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 710 | + args=(permuted, 0, 0, 2, 1), |
| 711 | + ) |
| 712 | + builder.output([sliced]) |
| 713 | + original = builder.get_graph_module() |
| 714 | + |
| 715 | + result = transform_and_check_numerics( |
| 716 | + original, |
| 717 | + (torch.randn(10, 3, 4, 5),), |
| 718 | + MoveSliceBeforePermutePass(), |
| 719 | + ) |
| 720 | + self.assertTrue(result.modified) |
| 721 | + |
| 722 | + nodes = get_compute_nodes_in_gm(result.graph_module) |
| 723 | + self.assertEqual(len(nodes), 2) |
| 724 | + self.assertEqual(nodes[0], exir_ops.edge.aten.slice_copy) |
| 725 | + self.assertEqual(nodes[1], exir_ops.edge.aten.permute_copy) |
| 726 | + |
| 727 | + def test_dim0_slice_small_reduction_not_moved(self) -> None: |
| 728 | + """Dim-0 slice removing <50% of data → not profitable, kept.""" |
| 729 | + builder = GraphBuilder() |
| 730 | + x = builder.placeholder("x", torch.randn(10, 3, 4, 5)) |
| 731 | + permuted = builder.call_operator( |
| 732 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 733 | + ) |
| 734 | + sliced = builder.call_operator( |
| 735 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 736 | + args=(permuted, 0, 0, 8, 1), |
| 737 | + ) |
| 738 | + builder.output([sliced]) |
| 739 | + original = builder.get_graph_module() |
| 740 | + |
| 741 | + result = cast(PassResult, MoveSliceBeforePermutePass()(original)) |
| 742 | + self.assertFalse(result.modified) |
| 743 | + |
| 744 | + def test_non_dim0_slice_always_moved(self) -> None: |
| 745 | + """Non-dim-0 slice → always profitable, moved regardless of reduction.""" |
| 746 | + builder = GraphBuilder() |
| 747 | + x = builder.placeholder("x", torch.randn(10, 3, 4, 5)) |
| 748 | + permuted = builder.call_operator( |
| 749 | + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) |
| 750 | + ) |
| 751 | + sliced = builder.call_operator( |
| 752 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 753 | + args=(permuted, 2, 0, 3, 1), |
| 754 | + ) |
| 755 | + builder.output([sliced]) |
| 756 | + original = builder.get_graph_module() |
| 757 | + |
| 758 | + result = transform_and_check_numerics( |
| 759 | + original, |
| 760 | + (torch.randn(10, 3, 4, 5),), |
| 761 | + MoveSliceBeforePermutePass(), |
| 762 | + ) |
| 763 | + self.assertTrue(result.modified) |
0 commit comments