|
8 | 8 | import torch |
9 | 9 | from executorch.backends.arm._passes import ( |
10 | 10 | AnnotateOutputDimOrderPass, |
| 11 | + EnsureUniqueOutputNodesPass, |
| 12 | + FuseEqualPlaceholdersPass, |
11 | 13 | ToTosaMemoryFormatPass, |
12 | 14 | ) |
13 | 15 |
|
14 | 16 | from executorch.backends.arm.test import common |
15 | 17 | from executorch.backends.arm.test.tester.test_pipeline import ( |
16 | 18 | PassPipeline, |
| 19 | + TosaPipelineFP, |
17 | 20 | TosaPipelineINT, |
18 | 21 | ) |
19 | 22 | from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass |
| 23 | +from executorch.exir.dialects._ops import ops as exir_ops |
20 | 24 |
|
21 | 25 | input_t = Tuple[torch.Tensor] # Input x |
22 | 26 |
|
@@ -177,6 +181,26 @@ def get_inputs(self) -> input_t: |
177 | 181 | return (torch.rand(4, 4, 4, 4),) |
178 | 182 |
|
179 | 183 |
|
| 184 | +class DuplicateConstantOutputs(torch.nn.Module): |
| 185 | + def __init__(self) -> None: |
| 186 | + super().__init__() |
| 187 | + self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) |
| 188 | + self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) |
| 189 | + |
| 190 | + def forward(self, x: torch.Tensor): |
| 191 | + return self.grid0, self.grid1, x |
| 192 | + |
| 193 | + |
| 194 | +class DuplicateConstantOutputsWithAdd(torch.nn.Module): |
| 195 | + def __init__(self) -> None: |
| 196 | + super().__init__() |
| 197 | + self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) |
| 198 | + self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) |
| 199 | + |
| 200 | + def forward(self, x: torch.Tensor): |
| 201 | + return self.grid0, self.grid1, x + x |
| 202 | + |
| 203 | + |
180 | 204 | modules: Dict[str, ModuleMetadata] = { |
181 | 205 | "no_nhwc": NoNHWC(), |
182 | 206 | "parallel_clusters": ParallelClusters(), |
@@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No |
209 | 233 | module_nn = cast(torch.nn.Module, module) |
210 | 234 | pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) |
211 | 235 | pipeline.run() |
| 236 | + |
| 237 | + |
| 238 | +def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None: |
| 239 | + pipeline = PassPipeline[input_t]( |
| 240 | + DuplicateConstantOutputs(), |
| 241 | + (torch.rand(1, 2, 32, 32),), |
| 242 | + quantize=False, |
| 243 | + pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], |
| 244 | + passes_with_exported_program=[ |
| 245 | + FuseEqualPlaceholdersPass, |
| 246 | + ToTosaMemoryFormatPass, |
| 247 | + EnsureUniqueOutputNodesPass, |
| 248 | + ], |
| 249 | + ) |
| 250 | + pipeline.pop_stage("run_method_and_compare_outputs") |
| 251 | + pipeline.run() |
| 252 | + |
| 253 | + graph_module = pipeline.tester.get_artifact().exported_program().graph_module |
| 254 | + output_node = graph_module.graph.output_node() |
| 255 | + outputs = list(output_node.args[0]) |
| 256 | + |
| 257 | + assert outputs[0] is not outputs[1] |
| 258 | + assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default |
| 259 | + assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default |
| 260 | + assert outputs[0].args[0] is outputs[1].args[0] |
| 261 | + |
| 262 | + |
| 263 | +def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None: |
| 264 | + pipeline = TosaPipelineFP[input_t]( |
| 265 | + DuplicateConstantOutputsWithAdd(), |
| 266 | + (torch.rand(1, 2, 32, 32),), |
| 267 | + [], |
| 268 | + [], |
| 269 | + ) |
| 270 | + pipeline.run() |
0 commit comments