44# LICENSE file in the root directory of this source tree.
55
66import random
7- from pathlib import Path
8- from typing import Any
7+ from dataclasses import dataclass
8+ from typing import Any , Tuple
99
1010import torch
11- import torch . nn as nn
11+
1212from executorch .backends .arm .test import common
1313from executorch .backends .arm .test .tester .test_pipeline import TosaPipelineINT
14- from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
14+
15+
16+ InputT = Tuple [Any , ...]
1517
1618
1719class HighRankPermuteViewModel (torch .nn .Module ):
1820 def __init__ (self , ops : list [tuple [str , Any ]]):
1921 super ().__init__ ()
2022 self .ops = ops
21- self .block = nn .Sequential (
22- nn .Conv2d (
23+ self .block = torch . nn .Sequential (
24+ torch . nn .Conv2d (
2325 in_channels = 3 ,
2426 out_channels = 64 ,
2527 kernel_size = 3 ,
2628 stride = 2 ,
2729 padding = 1 ,
2830 ),
29- nn .ReLU (),
31+ torch . nn .ReLU (),
3032 )
3133
3234 def forward (self , x ):
@@ -41,6 +43,13 @@ def forward(self, x):
4143 return x
4244
4345
46+ @dataclass (frozen = True )
47+ class TransposeInvariantCase :
48+ module : torch .nn .Module
49+ inputs : InputT
50+ expected_transposes : int
51+
52+
4453def _random_non_identity_permutation (
4554 rng : random .Random , rank : int
4655) -> tuple [int , ...] | None :
@@ -130,7 +139,6 @@ def _generate_chain(
130139 shape = new_shape
131140 break
132141
133- # Ensure each case has at least one rank>4 permute.
134142 while len (shape ) <= 4 :
135143 new_shape = _reshape_add_singleton (rng , shape )
136144 if new_shape is None :
@@ -146,108 +154,33 @@ def _generate_chain(
146154 return ops
147155
148156
149- def _build_cases () -> dict [str , HighRankPermuteViewModel ]:
157+ def _build_high_rank_permute_cases () -> dict [str , TransposeInvariantCase ]:
150158 rng = random .Random (
151159 20260225
152160 ) # nosec B311: deterministic RNG for test case generation
153- start_shape = [1 , 16 , 16 , 64 ] # conv output from input 1x3x32x32 after NHWC permute
154- cases : dict [str , HighRankPermuteViewModel ] = {}
161+ start_shape = [1 , 16 , 16 , 64 ]
162+ expected_transpose_counts = [6 , 11 , 10 , 10 , 7 , 7 , 10 , 10 , 8 , 10 ]
163+ cases : dict [str , TransposeInvariantCase ] = {}
155164 for idx in range (10 ):
156165 ops = _generate_chain (rng , start_shape , steps = 8 )
157- cases [f"fuzz_case_{ idx } " ] = HighRankPermuteViewModel (ops )
166+ cases [f"high_rank_permute_fuzz_case_{ idx } " ] = TransposeInvariantCase (
167+ module = HighRankPermuteViewModel (ops ).eval (),
168+ inputs = (torch .randn (1 , 3 , 32 , 32 ),),
169+ expected_transposes = expected_transpose_counts [idx ],
170+ )
158171 return cases
159172
160173
161- def _run_model (model : torch .nn .Module , out_dir : str ) -> Path :
162- sample = torch .randn (1 , 3 , 32 , 32 )
163- pipeline = TosaPipelineINT [tuple [torch .Tensor ]](
164- model .eval (),
165- (sample ,),
174+ @common .parametrize ("case" , _build_high_rank_permute_cases ())
175+ def test_transpose_invariants_tosa_INT_high_rank_permute_view (
176+ case : TransposeInvariantCase ,
177+ ) -> None :
178+ pipeline = TosaPipelineINT [InputT ](
179+ case .module ,
180+ case .inputs ,
166181 aten_op = [],
167182 exir_op = [],
168183 run_on_tosa_ref_model = False ,
169- custom_path = out_dir ,
170- tosa_debug_mode = TosaCompileSpec .DebugMode .JSON ,
171- tosa_extensions = ["int16" , "int4" , "cf" ],
172184 )
185+ pipeline .count_tosa_ops ({"TRANSPOSE" : case .expected_transposes })
173186 pipeline .run ()
174-
175- tosa_files = sorted (Path (out_dir ).glob ("*.tosa" ))
176- assert tosa_files , f"No TOSA artifacts found in { out_dir } "
177- return tosa_files [0 ]
178-
179-
180- def _assert_transpose_invariants (tosa_path : Path ) -> int :
181- import tosa .Op as Op # type: ignore[import-not-found,import-untyped]
182- from tosa .TosaGraph import ( # type: ignore[import-not-found,import-untyped]
183- TosaGraph ,
184- )
185- from tosa .TransposeAttribute import ( # type: ignore[import-not-found,import-untyped]
186- TransposeAttribute ,
187- )
188-
189- graph = TosaGraph .GetRootAs (tosa_path .read_bytes (), 0 )
190- block = graph .Regions (0 ).Blocks (0 )
191-
192- shape_by_name = {
193- block .Tensors (i ).Name ().decode (): list (block .Tensors (i ).ShapeAsNumpy ())
194- for i in range (block .TensorsLength ())
195- }
196-
197- op_enum = Op .Op ()
198- op_value_to_name = {
199- getattr (op_enum , name ): name for name in dir (op_enum ) if name .isupper ()
200- }
201-
202- high_rank_transpose_count = 0
203- for i in range (block .OperatorsLength ()):
204- op = block .Operators (i )
205- if op_value_to_name .get (op .Op ()) != "TRANSPOSE" :
206- continue
207-
208- inputs = [op .Inputs (j ).decode () for j in range (op .InputsLength ())]
209- outputs = [op .Outputs (j ).decode () for j in range (op .OutputsLength ())]
210- assert len (inputs ) == 1 and len (outputs ) == 1 , (
211- f"Unexpected TRANSPOSE arity at op #{ i } : "
212- f"{ len (inputs )} inputs, { len (outputs )} outputs"
213- )
214-
215- attr_tbl = op .Attribute ()
216- transpose_attr = TransposeAttribute ()
217- transpose_attr .Init (attr_tbl .Bytes , attr_tbl .Pos )
218- perms = [int (perm ) for perm in transpose_attr .PermsAsNumpy ()]
219-
220- in_shape = [int (v ) for v in shape_by_name [inputs [0 ]]]
221- out_shape = [int (v ) for v in shape_by_name [outputs [0 ]]]
222-
223- rank = len (in_shape )
224- assert (
225- len (perms ) == rank
226- ), f"Invalid TRANSPOSE rank at op #{ i } : len(perms)={ len (perms )} rank={ rank } "
227- assert sorted (perms ) == list (
228- range (rank )
229- ), f"Invalid TRANSPOSE permutation at op #{ i } : perms={ perms } , rank={ rank } "
230- expected_out_shape = [in_shape [perm ] for perm in perms ]
231- assert expected_out_shape == out_shape , (
232- f"Invalid TRANSPOSE shape mapping at op #{ i } : "
233- f"perms={ perms } , in_shape={ in_shape } , out_shape={ out_shape } , "
234- f"expected_out_shape={ expected_out_shape } "
235- )
236- if rank > 4 :
237- high_rank_transpose_count += 1
238-
239- return high_rank_transpose_count
240-
241-
242- @common .parametrize ("model" , _build_cases ())
243- def test_high_rank_permute_view_tosa_INT_transpose_invariants (
244- model : torch .nn .Module , tmp_path
245- ):
246- out_dir = tmp_path / "high_rank_permute_view_fuzz"
247- out_dir .mkdir (parents = True , exist_ok = True )
248- tosa_path = _run_model (model , str (out_dir ))
249- assert tosa_path .exists (), f"Missing TOSA dump: { tosa_path } "
250- high_rank_count = _assert_transpose_invariants (tosa_path )
251- assert (
252- high_rank_count > 0
253- ), "Expected at least one rank>4 TRANSPOSE in generated case."
0 commit comments