1313 SupportedTOSAOperatorCheck ,
1414)
1515from executorch .backends .arm .tosa import TosaSpecification
16+ from executorch .backends .arm .tosa .resize_utils import get_tosa_resize_validation_error
1617from executorch .exir .dialects ._ops import ops as exir_ops
1718
1819
20+ def _is_upsample_node_tosa_supported (
21+ support_check : SupportedTOSAOperatorCheck ,
22+ node : fx .Node ,
23+ tosa_spec : TosaSpecification ,
24+ * ,
25+ align_corners : bool ,
26+ ) -> bool :
27+ input_node = ensure_type (fx .Node , node .args [0 ])
28+ input_size_yx = get_first_fake_tensor (input_node ).shape [2 :]
29+ output_size_yx = get_first_fake_tensor (node ).shape [2 :]
30+
31+ try :
32+ scale_y_n , scale_y_d , offset_y , border_y = (
33+ RewriteUpsamplePass .get_resize_parameters_1d (
34+ input_size_yx [0 ], output_size_yx [0 ], align_corners
35+ )
36+ )
37+ scale_x_n , scale_x_d , offset_x , border_x = (
38+ RewriteUpsamplePass .get_resize_parameters_1d (
39+ input_size_yx [1 ], output_size_yx [1 ], align_corners
40+ )
41+ )
42+ except RuntimeError as err :
43+ support_check .reporter .report_reject (node , str (err ))
44+ return False
45+
46+ # Validate the exact TOSA RESIZE parameters that RewriteUpsamplePass will
47+ # emit so support checks and fake-op validation reject the same cases.
48+ validation_error = get_tosa_resize_validation_error (
49+ input_hw = input_size_yx ,
50+ output_hw = output_size_yx ,
51+ scale = [scale_y_n , scale_y_d , scale_x_n , scale_x_d ],
52+ offset = [offset_y , offset_x ],
53+ border = [border_y , border_x ],
54+ tosa_spec = tosa_spec ,
55+ )
56+ if validation_error is not None :
57+ support_check .reporter .report_reject (node , validation_error )
58+ return False
59+
60+ return True
61+
62+
1963@register_tosa_support_check
2064class UpsampleNearest2dSupported (SupportedTOSAOperatorCheck ):
2165 """Provide the explicit TOSA support gate for nearest upsample."""
2266
2367 targets = [exir_ops .edge .aten .upsample_nearest2d .vec ]
2468
2569 def is_node_tosa_supported (
26- self , _node : fx .Node , _tosa_spec : TosaSpecification
70+ self , node : fx .Node , tosa_spec : TosaSpecification
2771 ) -> bool : # type: ignore[override, misc]
28- return True
72+ return _is_upsample_node_tosa_supported (
73+ self , node , tosa_spec , align_corners = False
74+ )
2975
3076
3177@register_tosa_support_check
@@ -37,33 +83,9 @@ class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck):
3783 targets = [exir_ops .edge .aten .upsample_bilinear2d .vec ]
3884
3985 def is_node_tosa_supported (
40- self , node : fx .Node , _tosa_spec : TosaSpecification
86+ self , node : fx .Node , tosa_spec : TosaSpecification
4187 ) -> bool : # type: ignore[override, misc]
42- input_node = ensure_type (fx .Node , node .args [0 ])
4388 align_corners = ensure_type (bool , node .args [2 ])
44- input_size_yx = get_first_fake_tensor (input_node ).shape [2 :]
45- output_size_yx = get_first_fake_tensor (node ).shape [2 :]
46-
47- try :
48- scale_y_n , scale_y_d , _ , _ = RewriteUpsamplePass .get_resize_parameters_1d (
49- input_size_yx [0 ], output_size_yx [0 ], align_corners
50- )
51- scale_x_n , scale_x_d , _ , _ = RewriteUpsamplePass .get_resize_parameters_1d (
52- input_size_yx [1 ], output_size_yx [1 ], align_corners
53- )
54- except RuntimeError as err :
55- self .reporter .report_reject (node , str (err ))
56- return False
57-
58- # get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for
59- # each spatial dimension. For align_corners=False, this is the effective
60- # output_size / input_size ratio, so the 1/16 boundary is checked
61- # directly in the same representation that RESIZE lowering will use.
62- if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n :
63- self .reporter .report_reject (
64- node ,
65- "Bilinear RESIZE downscale must be strictly greater than 1/16" ,
66- )
67- return False
68-
69- return True
89+ return _is_upsample_node_tosa_supported (
90+ self , node , tosa_spec , align_corners = align_corners
91+ )
0 commit comments