Skip to content

Commit 656850a

Browse files
Arm backend: Add test transpose count suite (#19022)
This gives a baseline of how many transposes are inserted in the graph for a number of important decompositions. Also updates test_high_rank_permute_view_invariants to use the same way of counting transposes instead of a custom solution. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 6be4fb5 commit 656850a

2 files changed

Lines changed: 591 additions & 100 deletions

File tree

backends/arm/test/misc/test_high_rank_permute_view_invariants.py

Lines changed: 33 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,31 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import random
7-
from pathlib import Path
8-
from typing import Any
7+
from dataclasses import dataclass
8+
from typing import Any, Tuple
99

1010
import torch
11-
import torch.nn as nn
11+
1212
from executorch.backends.arm.test import common
1313
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
14-
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
14+
15+
16+
InputT = Tuple[Any, ...]
1517

1618

1719
class HighRankPermuteViewModel(torch.nn.Module):
1820
def __init__(self, ops: list[tuple[str, Any]]):
1921
super().__init__()
2022
self.ops = ops
21-
self.block = nn.Sequential(
22-
nn.Conv2d(
23+
self.block = torch.nn.Sequential(
24+
torch.nn.Conv2d(
2325
in_channels=3,
2426
out_channels=64,
2527
kernel_size=3,
2628
stride=2,
2729
padding=1,
2830
),
29-
nn.ReLU(),
31+
torch.nn.ReLU(),
3032
)
3133

3234
def forward(self, x):
@@ -41,6 +43,13 @@ def forward(self, x):
4143
return x
4244

4345

46+
@dataclass(frozen=True)
47+
class TransposeInvariantCase:
48+
module: torch.nn.Module
49+
inputs: InputT
50+
expected_transposes: int
51+
52+
4453
def _random_non_identity_permutation(
4554
rng: random.Random, rank: int
4655
) -> tuple[int, ...] | None:
@@ -130,7 +139,6 @@ def _generate_chain(
130139
shape = new_shape
131140
break
132141

133-
# Ensure each case has at least one rank>4 permute.
134142
while len(shape) <= 4:
135143
new_shape = _reshape_add_singleton(rng, shape)
136144
if new_shape is None:
@@ -146,108 +154,33 @@ def _generate_chain(
146154
return ops
147155

148156

149-
def _build_cases() -> dict[str, HighRankPermuteViewModel]:
157+
def _build_high_rank_permute_cases() -> dict[str, TransposeInvariantCase]:
150158
rng = random.Random(
151159
20260225
152160
) # nosec B311: deterministic RNG for test case generation
153-
start_shape = [1, 16, 16, 64] # conv output from input 1x3x32x32 after NHWC permute
154-
cases: dict[str, HighRankPermuteViewModel] = {}
161+
start_shape = [1, 16, 16, 64]
162+
expected_transpose_counts = [6, 11, 10, 10, 7, 7, 10, 10, 8, 10]
163+
cases: dict[str, TransposeInvariantCase] = {}
155164
for idx in range(10):
156165
ops = _generate_chain(rng, start_shape, steps=8)
157-
cases[f"fuzz_case_{idx}"] = HighRankPermuteViewModel(ops)
166+
cases[f"high_rank_permute_fuzz_case_{idx}"] = TransposeInvariantCase(
167+
module=HighRankPermuteViewModel(ops).eval(),
168+
inputs=(torch.randn(1, 3, 32, 32),),
169+
expected_transposes=expected_transpose_counts[idx],
170+
)
158171
return cases
159172

160173

161-
def _run_model(model: torch.nn.Module, out_dir: str) -> Path:
162-
sample = torch.randn(1, 3, 32, 32)
163-
pipeline = TosaPipelineINT[tuple[torch.Tensor]](
164-
model.eval(),
165-
(sample,),
174+
@common.parametrize("case", _build_high_rank_permute_cases())
175+
def test_transpose_invariants_tosa_INT_high_rank_permute_view(
176+
case: TransposeInvariantCase,
177+
) -> None:
178+
pipeline = TosaPipelineINT[InputT](
179+
case.module,
180+
case.inputs,
166181
aten_op=[],
167182
exir_op=[],
168183
run_on_tosa_ref_model=False,
169-
custom_path=out_dir,
170-
tosa_debug_mode=TosaCompileSpec.DebugMode.JSON,
171-
tosa_extensions=["int16", "int4", "cf"],
172184
)
185+
pipeline.count_tosa_ops({"TRANSPOSE": case.expected_transposes})
173186
pipeline.run()
174-
175-
tosa_files = sorted(Path(out_dir).glob("*.tosa"))
176-
assert tosa_files, f"No TOSA artifacts found in {out_dir}"
177-
return tosa_files[0]
178-
179-
180-
def _assert_transpose_invariants(tosa_path: Path) -> int:
181-
import tosa.Op as Op # type: ignore[import-not-found,import-untyped]
182-
from tosa.TosaGraph import ( # type: ignore[import-not-found,import-untyped]
183-
TosaGraph,
184-
)
185-
from tosa.TransposeAttribute import ( # type: ignore[import-not-found,import-untyped]
186-
TransposeAttribute,
187-
)
188-
189-
graph = TosaGraph.GetRootAs(tosa_path.read_bytes(), 0)
190-
block = graph.Regions(0).Blocks(0)
191-
192-
shape_by_name = {
193-
block.Tensors(i).Name().decode(): list(block.Tensors(i).ShapeAsNumpy())
194-
for i in range(block.TensorsLength())
195-
}
196-
197-
op_enum = Op.Op()
198-
op_value_to_name = {
199-
getattr(op_enum, name): name for name in dir(op_enum) if name.isupper()
200-
}
201-
202-
high_rank_transpose_count = 0
203-
for i in range(block.OperatorsLength()):
204-
op = block.Operators(i)
205-
if op_value_to_name.get(op.Op()) != "TRANSPOSE":
206-
continue
207-
208-
inputs = [op.Inputs(j).decode() for j in range(op.InputsLength())]
209-
outputs = [op.Outputs(j).decode() for j in range(op.OutputsLength())]
210-
assert len(inputs) == 1 and len(outputs) == 1, (
211-
f"Unexpected TRANSPOSE arity at op #{i}: "
212-
f"{len(inputs)} inputs, {len(outputs)} outputs"
213-
)
214-
215-
attr_tbl = op.Attribute()
216-
transpose_attr = TransposeAttribute()
217-
transpose_attr.Init(attr_tbl.Bytes, attr_tbl.Pos)
218-
perms = [int(perm) for perm in transpose_attr.PermsAsNumpy()]
219-
220-
in_shape = [int(v) for v in shape_by_name[inputs[0]]]
221-
out_shape = [int(v) for v in shape_by_name[outputs[0]]]
222-
223-
rank = len(in_shape)
224-
assert (
225-
len(perms) == rank
226-
), f"Invalid TRANSPOSE rank at op #{i}: len(perms)={len(perms)} rank={rank}"
227-
assert sorted(perms) == list(
228-
range(rank)
229-
), f"Invalid TRANSPOSE permutation at op #{i}: perms={perms}, rank={rank}"
230-
expected_out_shape = [in_shape[perm] for perm in perms]
231-
assert expected_out_shape == out_shape, (
232-
f"Invalid TRANSPOSE shape mapping at op #{i}: "
233-
f"perms={perms}, in_shape={in_shape}, out_shape={out_shape}, "
234-
f"expected_out_shape={expected_out_shape}"
235-
)
236-
if rank > 4:
237-
high_rank_transpose_count += 1
238-
239-
return high_rank_transpose_count
240-
241-
242-
@common.parametrize("model", _build_cases())
243-
def test_high_rank_permute_view_tosa_INT_transpose_invariants(
244-
model: torch.nn.Module, tmp_path
245-
):
246-
out_dir = tmp_path / "high_rank_permute_view_fuzz"
247-
out_dir.mkdir(parents=True, exist_ok=True)
248-
tosa_path = _run_model(model, str(out_dir))
249-
assert tosa_path.exists(), f"Missing TOSA dump: {tosa_path}"
250-
high_rank_count = _assert_transpose_invariants(tosa_path)
251-
assert (
252-
high_rank_count > 0
253-
), "Expected at least one rank>4 TRANSPOSE in generated case."

0 commit comments

Comments
 (0)