Skip to content

Commit 2c9c9dd

Browse files
authored
Arm backend: Enable Swin2SR TOSA ref tests (#19771)
Summary: - Enable Swin2SR FP and INT TOSA pipelines to run through the reference model. - Keep quantized VGF runtime execution Linux-only until Darwin VKML validation is available. - Record current Swin2SR partition boundaries and track delegation gaps in MLETORCH-2163. Test Plan: - lintrunner on test_swin2sr_arm.py - backends/arm/scripts/pre-push cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Usamah Zaheer <usamah.zaheer@arm.com>
1 parent 77df9b7 commit 2c9c9dd

1 file changed

Lines changed: 26 additions & 15 deletions

File tree

backends/arm/test/models/test_swin2sr_arm.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import sys
67
from typing import Tuple
78

89
import torch
@@ -17,7 +18,7 @@
1718

1819
input_t = Tuple[torch.Tensor]
1920

20-
exir_ops = [
21+
ops_expected_absent_after_lowering = [
2122
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
2223
"executorch_exir_dialects_edge__ops_aten_convolution_default",
2324
"executorch_exir_dialects_edge__ops_aten_layer_norm_default",
@@ -27,6 +28,21 @@
2728
"executorch_exir_dialects_edge__ops_aten_softmax_int",
2829
]
2930

31+
# TODO/MLETORCH-2163: Investigate Swin2SR delegation gaps around index/view
32+
# in FP and Q/DQ, clamp, and expand_copy in INT.
33+
swin2sr_fp_lowered_outer_graph_ops = {
34+
"torch.ops.higher_order.executorch_call_delegate": 2,
35+
"executorch_exir_dialects_edge__ops_aten_index_Tensor": 2,
36+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
37+
}
38+
swin2sr_int_lowered_outer_graph_ops = {
39+
"torch.ops.higher_order.executorch_call_delegate": 3,
40+
"executorch_exir_dialects_edge__ops_aten_clamp_default": 4,
41+
"executorch_exir_dialects_edge__ops_aten_expand_copy_default": 4,
42+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5,
43+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6,
44+
}
45+
3046

3147
class TinySwin2SR(torch.nn.Module):
3248
def __init__(self):
@@ -62,12 +78,10 @@ def test_swin2sr_tosa_FP():
6278
model,
6379
model_inputs,
6480
aten_op=[],
65-
exir_op=exir_ops,
81+
exir_op=ops_expected_absent_after_lowering,
6682
use_to_edge_transform_and_lower=True,
6783
)
68-
pipeline.pop_stage("check_count.exir")
69-
# TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
70-
pipeline.pop_stage("run_method_and_compare_outputs")
84+
pipeline.change_args("check_count.exir", swin2sr_fp_lowered_outer_graph_ops)
7185
pipeline.run()
7286

7387

@@ -77,12 +91,10 @@ def test_swin2sr_tosa_INT():
7791
model,
7892
model_inputs,
7993
aten_op=[],
80-
exir_op=exir_ops,
94+
exir_op=ops_expected_absent_after_lowering,
8195
use_to_edge_transform_and_lower=True,
8296
)
83-
pipeline.pop_stage("check_count.exir")
84-
# TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
85-
pipeline.pop_stage("run_method_and_compare_outputs")
97+
pipeline.change_args("check_count.exir", swin2sr_int_lowered_outer_graph_ops)
8698
pipeline.run()
8799

88100

@@ -93,13 +105,12 @@ def test_swin2sr_vgf_quant():
93105
model,
94106
model_inputs,
95107
aten_op=[],
96-
exir_op=exir_ops,
108+
exir_op=ops_expected_absent_after_lowering,
97109
use_to_edge_transform_and_lower=True,
98110
quantize=True,
111+
run_on_vulkan_runtime=sys.platform == "linux",
99112
)
100-
pipeline.pop_stage("check_count.exir")
101-
# TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
102-
pipeline.pop_stage("run_method_and_compare_outputs")
113+
pipeline.change_args("check_count.exir", swin2sr_int_lowered_outer_graph_ops)
103114
pipeline.run()
104115

105116

@@ -110,9 +121,9 @@ def test_swin2sr_vgf_no_quant():
110121
model,
111122
model_inputs,
112123
aten_op=[],
113-
exir_op=exir_ops,
124+
exir_op=ops_expected_absent_after_lowering,
114125
use_to_edge_transform_and_lower=True,
115126
quantize=False,
116127
)
117-
pipeline.pop_stage("check_count.exir")
128+
pipeline.change_args("check_count.exir", swin2sr_fp_lowered_outer_graph_ops)
118129
pipeline.run()

0 commit comments

Comments
 (0)