Skip to content

Commit acf1ad9

Browse files
authored
Handle rank-changing views in RemovePermutesAroundElementwiseOps (pytorch#19538)
Differential Revision: D104775244 Pull Request resolved: pytorch#19538
1 parent 1e76bb3 commit acf1ad9

5 files changed

Lines changed: 764 additions & 97 deletions

File tree

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111

1212

1313
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
14-
permutable_ops = {
15-
*RemovePermutesAroundElementwiseOps.permutable_ops,
16-
*TableOps.unary_table_ops.keys(),
17-
*TableOps.special_table_ops,
18-
exir_ops.backend.tosa.RESCALE.default,
19-
exir_ops.backend.tosa.TABLE.default,
20-
}
14+
def __init__(self) -> None:
15+
super().__init__(
16+
extra_permutable_ops={
17+
*TableOps.unary_table_ops.keys(),
18+
*TableOps.special_table_ops,
19+
exir_ops.backend.tosa.RESCALE.default,
20+
exir_ops.backend.tosa.TABLE.default,
21+
}
22+
)
2123

2224
def permute_subgraph(self, subgraph):
2325
# Original function will always permute constant nodes which is wrong for table ops

backends/cadence/aot/remove_ops.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -603,16 +603,16 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
603603

604604
@register_cadence_pass(CadencePassAttribute(opt_level=2))
605605
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
606-
permutable_ops: set[EdgeOpOverload] = (
607-
_SharedRemovePermutesAroundElementwiseOps.permutable_ops
608-
| {
609-
exir_ops.edge.cadence.quantize_per_tensor.default,
610-
exir_ops.edge.cadence.dequantize_per_tensor.default,
611-
exir_ops.edge.cadence.quantized_relu.per_tensor,
612-
exir_ops.edge.cadence.requantize.per_tensor,
613-
exir_ops.edge.cadence.quantized_add.per_tensor,
614-
}
615-
)
606+
def __init__(self) -> None:
607+
super().__init__(
608+
extra_permutable_ops={
609+
exir_ops.edge.cadence.quantize_per_tensor.default,
610+
exir_ops.edge.cadence.dequantize_per_tensor.default,
611+
exir_ops.edge.cadence.quantized_relu.per_tensor,
612+
exir_ops.edge.cadence.requantize.per_tensor,
613+
exir_ops.edge.cadence.quantized_add.per_tensor,
614+
}
615+
)
616616

617617

618618
@register_cadence_pass(CadencePassAttribute(opt_level=2))

0 commit comments

Comments
 (0)