|
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +from types import SimpleNamespace |
| 7 | +from typing import cast |
| 8 | + |
6 | 9 | import torch |
7 | 10 | from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import ( |
8 | 11 | RemovePermutesAroundElementwiseTosaOps, |
|
11 | 14 | TosaLoweringContext, |
12 | 15 | TosaSpecification, |
13 | 16 | ) |
| 17 | +from executorch.exir import ExportedProgram |
14 | 18 | from executorch.exir.dialects._ops import ops as exir_ops |
15 | 19 |
|
16 | 20 | TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT") |
| 21 | +TOSA_FP_SPEC = TosaSpecification.create_from_string("TOSA-1.0+FP") |
17 | 22 | PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default |
18 | 23 | RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default |
19 | | -TABLE_TARGET = exir_ops.backend.tosa.TABLE.default |
| 24 | +MUL_TARGET = exir_ops.edge.aten.mul.Tensor |
| 25 | +ADD_TARGET = exir_ops.edge.aten.add.Tensor |
| 26 | +ERF_TARGET = exir_ops.edge.aten.erf.default |
| 27 | + |
| 28 | + |
| 29 | +def _fake_exported_program() -> ExportedProgram: |
| 30 | + return cast( |
| 31 | + ExportedProgram, |
| 32 | + SimpleNamespace( |
| 33 | + graph_signature=SimpleNamespace( |
| 34 | + inputs_to_buffers={}, |
| 35 | + inputs_to_lifted_tensor_constants={}, |
| 36 | + inputs_to_parameters={}, |
| 37 | + ) |
| 38 | + ), |
| 39 | + ) |
20 | 40 |
|
21 | 41 |
|
22 | 42 | def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int: |
@@ -52,8 +72,125 @@ def test_remove_permutes_around_rescale_tosa_INT() -> None: |
52 | 72 | graph_module = torch.fx.GraphModule({}, graph) |
53 | 73 |
|
54 | 74 | with TosaLoweringContext(TOSA_INT_SPEC): |
55 | | - result = RemovePermutesAroundElementwiseTosaOps().call(graph_module) |
| 75 | + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( |
| 76 | + graph_module |
| 77 | + ) |
56 | 78 |
|
57 | 79 | assert result.modified |
58 | 80 | assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0 |
59 | 81 | assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1 |
| 82 | + |
| 83 | + |
| 84 | +def test_remove_permutes_around_gelu_with_folded_scalar_constants_tosa_FP() -> None: |
| 85 | + graph = torch.fx.Graph() |
| 86 | + x = graph.placeholder("x") |
| 87 | + x.meta["val"] = torch.randn(1, 2, 3, 4) |
| 88 | + |
| 89 | + scalar_constants = [] |
| 90 | + for i in range(3): |
| 91 | + const = graph.placeholder(f"c_scalar_{i}") |
| 92 | + const.meta["val"] = torch.randn(1, 1, 1, 1) |
| 93 | + scalar_constants.append(const) |
| 94 | + |
| 95 | + permute_in = graph.create_node( |
| 96 | + "call_function", |
| 97 | + PERMUTE_TARGET, |
| 98 | + args=(x, [0, 2, 3, 1]), |
| 99 | + ) |
| 100 | + permute_in.meta["val"] = torch.randn(1, 3, 4, 2) |
| 101 | + mul_0 = graph.create_node( |
| 102 | + "call_function", |
| 103 | + MUL_TARGET, |
| 104 | + args=(permute_in, scalar_constants[0]), |
| 105 | + ) |
| 106 | + mul_0.meta["val"] = torch.randn(1, 3, 4, 2) |
| 107 | + erf = graph.create_node("call_function", ERF_TARGET, args=(mul_0,)) |
| 108 | + erf.meta["val"] = torch.randn(1, 3, 4, 2) |
| 109 | + add = graph.create_node( |
| 110 | + "call_function", |
| 111 | + ADD_TARGET, |
| 112 | + args=(erf, scalar_constants[1]), |
| 113 | + ) |
| 114 | + add.meta["val"] = torch.randn(1, 3, 4, 2) |
| 115 | + mul_1 = graph.create_node( |
| 116 | + "call_function", |
| 117 | + MUL_TARGET, |
| 118 | + args=(add, scalar_constants[2]), |
| 119 | + ) |
| 120 | + mul_1.meta["val"] = torch.randn(1, 3, 4, 2) |
| 121 | + mul_2 = graph.create_node( |
| 122 | + "call_function", |
| 123 | + MUL_TARGET, |
| 124 | + args=(permute_in, mul_1), |
| 125 | + ) |
| 126 | + mul_2.meta["val"] = torch.randn(1, 3, 4, 2) |
| 127 | + permute_out = graph.create_node( |
| 128 | + "call_function", |
| 129 | + PERMUTE_TARGET, |
| 130 | + args=(mul_2, [0, 3, 1, 2]), |
| 131 | + ) |
| 132 | + permute_out.meta["val"] = torch.randn(1, 2, 3, 4) |
| 133 | + graph.output(permute_out) |
| 134 | + |
| 135 | + graph_module = torch.fx.GraphModule({}, graph) |
| 136 | + |
| 137 | + with TosaLoweringContext(TOSA_FP_SPEC): |
| 138 | + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( |
| 139 | + graph_module |
| 140 | + ) |
| 141 | + |
| 142 | + assert result.modified |
| 143 | + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 3 |
| 144 | + assert _count_nodes(result.graph_module, ERF_TARGET) == 1 |
| 145 | + |
| 146 | + |
| 147 | +def test_remove_permutes_skips_stale_shared_boundary_subgraph_tosa_FP() -> None: |
| 148 | + graph = torch.fx.Graph() |
| 149 | + x = graph.placeholder("x") |
| 150 | + x.meta["val"] = torch.randn(1, 16, 16, 8) |
| 151 | + |
| 152 | + channel_const = graph.placeholder("p_layer_norm_weight") |
| 153 | + channel_const.meta["val"] = torch.randn(1, 1, 1, 8) |
| 154 | + |
| 155 | + permute_in = graph.create_node( |
| 156 | + "call_function", |
| 157 | + PERMUTE_TARGET, |
| 158 | + args=(x, [0, 3, 1, 2]), |
| 159 | + ) |
| 160 | + permute_in.meta["val"] = torch.randn(1, 8, 16, 16) |
| 161 | + first_mul = graph.create_node( |
| 162 | + "call_function", |
| 163 | + MUL_TARGET, |
| 164 | + args=(permute_in, permute_in), |
| 165 | + ) |
| 166 | + first_mul.meta["val"] = torch.randn(1, 8, 16, 16) |
| 167 | + shared_permute = graph.create_node( |
| 168 | + "call_function", |
| 169 | + PERMUTE_TARGET, |
| 170 | + args=(first_mul, [0, 2, 3, 1]), |
| 171 | + ) |
| 172 | + shared_permute.meta["val"] = torch.randn(1, 16, 16, 8) |
| 173 | + second_mul = graph.create_node( |
| 174 | + "call_function", |
| 175 | + MUL_TARGET, |
| 176 | + args=(shared_permute, channel_const), |
| 177 | + ) |
| 178 | + second_mul.meta["val"] = torch.randn(1, 16, 16, 8) |
| 179 | + permute_out = graph.create_node( |
| 180 | + "call_function", |
| 181 | + PERMUTE_TARGET, |
| 182 | + args=(second_mul, [0, 3, 1, 2]), |
| 183 | + ) |
| 184 | + permute_out.meta["val"] = torch.randn(1, 8, 16, 16) |
| 185 | + graph.output(permute_out) |
| 186 | + |
| 187 | + graph_module = torch.fx.GraphModule({}, graph) |
| 188 | + |
| 189 | + with TosaLoweringContext(TOSA_FP_SPEC): |
| 190 | + result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( |
| 191 | + graph_module |
| 192 | + ) |
| 193 | + |
| 194 | + assert result.modified |
| 195 | + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 1 |
| 196 | + assert second_mul.args[1] is channel_const |
0 commit comments