|
26 | 26 | MoveSliceBeforePermutePass, |
27 | 27 | PostponeDequantizeOpBelowUseChainPass, |
28 | 28 | PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, |
| 29 | + PropagateSlice, |
29 | 30 | SinkOpsCloserToUsePass, |
30 | 31 | ) |
31 | 32 | from executorch.backends.test.graph_builder import GraphBuilder |
@@ -761,3 +762,168 @@ def test_non_dim0_slice_always_moved(self) -> None: |
761 | 762 | MoveSliceBeforePermutePass(), |
762 | 763 | ) |
763 | 764 | self.assertTrue(result.modified) |
| 765 | + |
| 766 | + |
| 767 | +class TestPropagateSlice(unittest.TestCase): |
| 768 | + def test_swap_quantize_slice(self) -> None: |
| 769 | + builder = GraphBuilder() |
| 770 | + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) |
| 771 | + quant = builder.call_operator( |
| 772 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 773 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 774 | + ) |
| 775 | + sliced = builder.call_operator( |
| 776 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 777 | + args=(quant, 0, 0, 4, 2), |
| 778 | + ) |
| 779 | + builder.output([sliced]) |
| 780 | + gm = builder.get_graph_module() |
| 781 | + |
| 782 | + result = PropagateSlice().call(gm) |
| 783 | + |
| 784 | + self.assertTrue(result.modified) |
| 785 | + |
| 786 | + slice_nodes = gm.graph.find_nodes( |
| 787 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 788 | + ) |
| 789 | + self.assertEqual(len(slice_nodes), 1) |
| 790 | + slice_node = slice_nodes[0] |
| 791 | + self.assertEqual(slice_node.args[0].name, "x") |
| 792 | + self.assertEqual(list(slice_node.meta["val"].shape), [2, 60, 1, 1]) |
| 793 | + |
| 794 | + quant_nodes = gm.graph.find_nodes( |
| 795 | + op="call_function", |
| 796 | + target=exir_ops.edge.cadence.quantize_per_tensor.default, |
| 797 | + ) |
| 798 | + self.assertEqual(len(quant_nodes), 1) |
| 799 | + self.assertEqual(quant_nodes[0].args[0], slice_node) |
| 800 | + self.assertEqual(list(quant_nodes[0].meta["val"].shape), [2, 60, 1, 1]) |
| 801 | + |
| 802 | + def test_swap_dequantize_slice(self) -> None: |
| 803 | + builder = GraphBuilder() |
| 804 | + x = builder.placeholder( |
| 805 | + "x", torch.randint(0, 255, (4, 60, 4, 4), dtype=torch.uint8) |
| 806 | + ) |
| 807 | + dequant = builder.call_operator( |
| 808 | + exir_ops.edge.cadence.dequantize_per_tensor.default, |
| 809 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 810 | + ) |
| 811 | + sliced = builder.call_operator( |
| 812 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 813 | + args=(dequant, 0, 0, 4, 2), |
| 814 | + ) |
| 815 | + builder.output([sliced]) |
| 816 | + gm = builder.get_graph_module() |
| 817 | + |
| 818 | + result = PropagateSlice().call(gm) |
| 819 | + |
| 820 | + self.assertTrue(result.modified) |
| 821 | + |
| 822 | + slice_nodes = gm.graph.find_nodes( |
| 823 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 824 | + ) |
| 825 | + self.assertEqual(len(slice_nodes), 1) |
| 826 | + self.assertEqual(slice_nodes[0].args[0].name, "x") |
| 827 | + |
| 828 | + def test_step_2_through_quantize(self) -> None: |
| 829 | + builder = GraphBuilder() |
| 830 | + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) |
| 831 | + quant = builder.call_operator( |
| 832 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 833 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 834 | + ) |
| 835 | + sliced = builder.call_operator( |
| 836 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 837 | + args=(quant, 0, 0, 4, 2), |
| 838 | + ) |
| 839 | + builder.output([sliced]) |
| 840 | + gm = builder.get_graph_module() |
| 841 | + |
| 842 | + result = PropagateSlice().call(gm) |
| 843 | + |
| 844 | + self.assertTrue(result.modified) |
| 845 | + |
| 846 | + slice_nodes = gm.graph.find_nodes( |
| 847 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 848 | + ) |
| 849 | + self.assertEqual(len(slice_nodes), 1) |
| 850 | + self.assertEqual(slice_nodes[0].args[4], 2) |
| 851 | + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 60, 1, 1]) |
| 852 | + |
| 853 | + def test_non_batch_dim_slice(self) -> None: |
| 854 | + builder = GraphBuilder() |
| 855 | + x = builder.placeholder("x", torch.randn(4, 60, 4, 4)) |
| 856 | + quant = builder.call_operator( |
| 857 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 858 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 859 | + ) |
| 860 | + sliced = builder.call_operator( |
| 861 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 862 | + args=(quant, 1, 0, 30, 1), |
| 863 | + ) |
| 864 | + builder.output([sliced]) |
| 865 | + gm = builder.get_graph_module() |
| 866 | + |
| 867 | + result = PropagateSlice().call(gm) |
| 868 | + |
| 869 | + self.assertTrue(result.modified) |
| 870 | + |
| 871 | + slice_nodes = gm.graph.find_nodes( |
| 872 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 873 | + ) |
| 874 | + self.assertEqual(len(slice_nodes), 1) |
| 875 | + self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 4, 4]) |
| 876 | + |
| 877 | + def test_no_swap_when_multi_user(self) -> None: |
| 878 | + builder = GraphBuilder() |
| 879 | + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) |
| 880 | + quant = builder.call_operator( |
| 881 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 882 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 883 | + ) |
| 884 | + sliced = builder.call_operator( |
| 885 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 886 | + args=(quant, 0, 0, 4, 2), |
| 887 | + ) |
| 888 | + builder.output([sliced, quant]) |
| 889 | + gm = builder.get_graph_module() |
| 890 | + |
| 891 | + result = PropagateSlice().call(gm) |
| 892 | + |
| 893 | + self.assertFalse(result.modified) |
| 894 | + |
| 895 | + def test_no_swap_noop_slice(self) -> None: |
| 896 | + builder = GraphBuilder() |
| 897 | + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) |
| 898 | + quant = builder.call_operator( |
| 899 | + exir_ops.edge.cadence.quantize_per_tensor.default, |
| 900 | + args=(x, 0.5, 0, 0, 255, torch.uint8), |
| 901 | + ) |
| 902 | + sliced = builder.call_operator( |
| 903 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 904 | + args=(quant, 0, 0, 4, 1), |
| 905 | + ) |
| 906 | + builder.output([sliced]) |
| 907 | + gm = builder.get_graph_module() |
| 908 | + |
| 909 | + result = PropagateSlice().call(gm) |
| 910 | + |
| 911 | + self.assertFalse(result.modified) |
| 912 | + |
| 913 | + def test_unsupported_parent_not_swapped(self) -> None: |
| 914 | + builder = GraphBuilder() |
| 915 | + x = builder.placeholder("x", torch.randn(4, 60, 1, 1)) |
| 916 | + relu = builder.call_operator( |
| 917 | + exir_ops.edge.aten.relu.default, |
| 918 | + args=(x,), |
| 919 | + ) |
| 920 | + sliced = builder.call_operator( |
| 921 | + exir_ops.edge.aten.slice_copy.Tensor, |
| 922 | + args=(relu, 0, 0, 4, 2), |
| 923 | + ) |
| 924 | + builder.output([sliced]) |
| 925 | + gm = builder.get_graph_module() |
| 926 | + |
| 927 | + result = PropagateSlice().call(gm) |
| 928 | + |
| 929 | + self.assertFalse(result.modified) |
0 commit comments