Skip to content

Commit 3447d08

Browse files
authored
Quantize moveaxis/movedim so they delegate to Ethos-U (#20314)
Differential Revision: D108478011 Pull Request resolved: #20453
1 parent 3169302 commit 3447d08

3 files changed

Lines changed: 62 additions & 0 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,16 @@ def _get_fixed_qparams_qspec(
631631
if _transpose_dimname is not None:
632632
_one_to_one_shared_input_qspec.add(_transpose_dimname)
633633

634+
for _op in (
635+
getattr(torch.ops.aten.moveaxis, "int", None),
636+
getattr(torch.ops.aten.moveaxis, "intlist", None),
637+
getattr(torch.ops.aten.movedim, "int", None),
638+
getattr(torch.ops.aten.movedim, "intlist", None),
639+
):
640+
if _op is not None:
641+
_one_to_one_shared_input_qspec.add(_op)
642+
643+
634644
_one_to_one_shared_input_or_input_act_qspec: set[OpOverload] = {
635645
torch.ops.aten.alias.default,
636646
torch.ops.aten.clone.default,

backends/arm/test/ops/test_permute.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ def forward(self, x):
7878
return torch.permute(x, self.dims)
7979

8080

81+
class SimpleMoveAxis(torch.nn.Module):
82+
83+
def forward(self, x):
84+
return torch.moveaxis(x, 1, -1)
85+
86+
8187
@common.parametrize(
8288
"test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16
8389
)
@@ -118,6 +124,17 @@ def test_permute_u55_INT(test_data):
118124
pipeline.run()
119125

120126

127+
def test_moveaxis_u55_INT():
128+
pipeline = EthosU55PipelineINT[input_t1](
129+
SimpleMoveAxis(),
130+
(torch.rand(1, 4, 5, 6),),
131+
"torch.ops.aten.moveaxis.int",
132+
exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default",
133+
run_on_fvp=False,
134+
)
135+
pipeline.run()
136+
137+
121138
@common.parametrize("test_data", test_data_suite_u55_reject)
122139
def test_permute_u55_INT_not_delegated(test_data: torch.Tensor):
123140
test_data, dims = test_data()

backends/arm/test/quantizer/test_generic_annotater.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,41 @@ def test_transpose_tosa_INT():
8989
)
9090

9191

92+
def test_moveaxis_movedim_tosa_INT():
93+
check_annotation(
94+
SingleOpModel(
95+
torch.moveaxis,
96+
(torch.randn(2, 3, 4),),
97+
source=1,
98+
destination=-1,
99+
),
100+
)
101+
check_annotation(
102+
SingleOpModel(
103+
torch.moveaxis,
104+
(torch.randn(2, 3, 4),),
105+
source=(0, 1),
106+
destination=(-1, -2),
107+
),
108+
)
109+
check_annotation(
110+
SingleOpModel(
111+
torch.movedim,
112+
(torch.randn(2, 3, 4),),
113+
source=1,
114+
destination=-1,
115+
),
116+
)
117+
check_annotation(
118+
SingleOpModel(
119+
torch.movedim,
120+
(torch.randn(2, 3, 4),),
121+
source=(0, 1),
122+
destination=(-1, -2),
123+
),
124+
)
125+
126+
92127
def test_tile_tosa_INT():
93128
check_annotation(
94129
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),

0 commit comments

Comments
 (0)