|
6 | 6 | import operator |
7 | 7 | from typing import cast, ClassVar, Dict, Protocol, Tuple |
8 | 8 |
|
| 9 | +import executorch.backends.arm.tosa.dialect # noqa: F401 |
9 | 10 | import torch |
10 | 11 | from executorch.backends.arm._passes.fuse_constant_ops_pass import ( |
11 | 12 | ComputeConstantOpsAOTPass, |
|
15 | 16 | from executorch.backends.arm.test import common |
16 | 17 | from executorch.backends.arm.test.tester.arm_tester import ArmTester |
17 | 18 | from executorch.backends.arm.test.tester.test_pipeline import PassPipeline |
18 | | -from executorch.backends.arm.tosa import TosaSpecification |
| 19 | +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype |
| 20 | +from executorch.backends.arm.tosa.specification import ( |
| 21 | + TosaLoweringContext, |
| 22 | + TosaSpecification, |
| 23 | +) |
19 | 24 | from executorch.backends.test.harness.stages import StageType |
| 25 | +from executorch.backends.test.program_builder import ProgramBuilder |
| 26 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 27 | +from torch.export.graph_signature import InputKind |
20 | 28 |
|
21 | 29 | input_t = Tuple[torch.Tensor] # Input x |
22 | 30 | input_t2 = Tuple[torch.Tensor, torch.Tensor] |
@@ -270,3 +278,70 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: |
270 | 278 | for node in pass_result.graph_module.graph.nodes |
271 | 279 | if node.op == "placeholder" |
272 | 280 | ] == ["aten_cat_default_fused_const"] |
| 281 | + |
| 282 | + |
| 283 | +def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: |
| 284 | + class FakeTosaTarget: |
| 285 | + def __str__(self) -> str: |
| 286 | + return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default" |
| 287 | + |
| 288 | + assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget()) |
| 289 | + assert FuseConstantArgsPass._is_tosa_dialect_op( |
| 290 | + exir_ops.backend.tosa.GATHER.default |
| 291 | + ) |
| 292 | + assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) |
| 293 | + |
| 294 | + |
| 295 | +def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: |
| 296 | + graph = torch.fx.Graph() |
| 297 | + shape_node = graph.placeholder("shape") |
| 298 | + shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE |
| 299 | + |
| 300 | + assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2])) |
| 301 | + assert not FuseConstantArgsPass._arg_contains_symbolic_shape( |
| 302 | + ([1, 2], {"pad": (0, 0)}) |
| 303 | + ) |
| 304 | + |
| 305 | + |
| 306 | +def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: |
| 307 | + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): |
| 308 | + builder = ProgramBuilder() |
| 309 | + values = builder.placeholder( |
| 310 | + "values", |
| 311 | + torch.randn(1, 4, 3), |
| 312 | + input_kind=InputKind.CONSTANT_TENSOR, |
| 313 | + ) |
| 314 | + indices = builder.placeholder( |
| 315 | + "indices", |
| 316 | + torch.tensor([[0, 2]], dtype=torch.int32), |
| 317 | + input_kind=InputKind.CONSTANT_TENSOR, |
| 318 | + ) |
| 319 | + gather = builder.call_operator( |
| 320 | + exir_ops.backend.tosa.GATHER.default, |
| 321 | + (values, indices), |
| 322 | + ) |
| 323 | + builder.output([gather]) |
| 324 | + |
| 325 | + exported_program = builder.get_program() |
| 326 | + graph_module = exported_program.graph_module |
| 327 | + |
| 328 | + with caplog.at_level("WARNING"): |
| 329 | + FuseConstantArgsPass(exported_program)(graph_module) |
| 330 | + |
| 331 | + warning_messages = [ |
| 332 | + record.getMessage() |
| 333 | + for record in caplog.records |
| 334 | + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" |
| 335 | + ] |
| 336 | + assert not any( |
| 337 | + "Failed to fuse constant op" in message and "GATHER" in message |
| 338 | + for message in warning_messages |
| 339 | + ) |
| 340 | + assert ( |
| 341 | + sum( |
| 342 | + node.op == "call_function" |
| 343 | + and node.target == exir_ops.backend.tosa.GATHER.default |
| 344 | + for node in graph_module.graph.nodes |
| 345 | + ) |
| 346 | + == 1 |
| 347 | + ) |
0 commit comments