33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ import sys
67from typing import Tuple
78
89import torch
1718
1819input_t = Tuple [torch .Tensor ]
1920
20- exir_ops = [
21+ ops_expected_absent_after_lowering = [
2122 "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
2223 "executorch_exir_dialects_edge__ops_aten_convolution_default" ,
2324 "executorch_exir_dialects_edge__ops_aten_layer_norm_default" ,
2728 "executorch_exir_dialects_edge__ops_aten_softmax_int" ,
2829]
2930
31+ # TODO/MLETORCH-2163: Investigate Swin2SR delegation gaps around index/view
32+ # in FP and Q/DQ, clamp, and expand_copy in INT.
33+ swin2sr_fp_lowered_outer_graph_ops = {
34+ "torch.ops.higher_order.executorch_call_delegate" : 2 ,
35+ "executorch_exir_dialects_edge__ops_aten_index_Tensor" : 2 ,
36+ "executorch_exir_dialects_edge__ops_aten_view_copy_default" : 2 ,
37+ }
38+ swin2sr_int_lowered_outer_graph_ops = {
39+ "torch.ops.higher_order.executorch_call_delegate" : 3 ,
40+ "executorch_exir_dialects_edge__ops_aten_clamp_default" : 4 ,
41+ "executorch_exir_dialects_edge__ops_aten_expand_copy_default" : 4 ,
42+ "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" : 5 ,
43+ "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" : 6 ,
44+ }
45+
3046
3147class TinySwin2SR (torch .nn .Module ):
3248 def __init__ (self ):
@@ -62,12 +78,10 @@ def test_swin2sr_tosa_FP():
6278 model ,
6379 model_inputs ,
6480 aten_op = [],
65- exir_op = exir_ops ,
81+ exir_op = ops_expected_absent_after_lowering ,
6682 use_to_edge_transform_and_lower = True ,
6783 )
68- pipeline .pop_stage ("check_count.exir" )
69- # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
70- pipeline .pop_stage ("run_method_and_compare_outputs" )
84+ pipeline .change_args ("check_count.exir" , swin2sr_fp_lowered_outer_graph_ops )
7185 pipeline .run ()
7286
7387
@@ -77,12 +91,10 @@ def test_swin2sr_tosa_INT():
7791 model ,
7892 model_inputs ,
7993 aten_op = [],
80- exir_op = exir_ops ,
94+ exir_op = ops_expected_absent_after_lowering ,
8195 use_to_edge_transform_and_lower = True ,
8296 )
83- pipeline .pop_stage ("check_count.exir" )
84- # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
85- pipeline .pop_stage ("run_method_and_compare_outputs" )
97+ pipeline .change_args ("check_count.exir" , swin2sr_int_lowered_outer_graph_ops )
8698 pipeline .run ()
8799
88100
@@ -93,13 +105,12 @@ def test_swin2sr_vgf_quant():
93105 model ,
94106 model_inputs ,
95107 aten_op = [],
96- exir_op = exir_ops ,
108+ exir_op = ops_expected_absent_after_lowering ,
97109 use_to_edge_transform_and_lower = True ,
98110 quantize = True ,
111+ run_on_vulkan_runtime = sys .platform == "linux" ,
99112 )
100- pipeline .pop_stage ("check_count.exir" )
101- # TODO: MLETORCH-2134 re-enable once Swin2SR runs on the TOSA ref model.
102- pipeline .pop_stage ("run_method_and_compare_outputs" )
113+ pipeline .change_args ("check_count.exir" , swin2sr_int_lowered_outer_graph_ops )
103114 pipeline .run ()
104115
105116
@@ -110,9 +121,9 @@ def test_swin2sr_vgf_no_quant():
110121 model ,
111122 model_inputs ,
112123 aten_op = [],
113- exir_op = exir_ops ,
124+ exir_op = ops_expected_absent_after_lowering ,
114125 use_to_edge_transform_and_lower = True ,
115126 quantize = False ,
116127 )
117- pipeline .pop_stage ("check_count.exir" )
128+ pipeline .change_args ("check_count.exir" , swin2sr_fp_lowered_outer_graph_ops )
118129 pipeline .run ()
0 commit comments