Skip to content

Commit 1cc2d53

Browse files
committed
Arm backend: Fix rejection criteria of TRANSPOSE from VIEW
When delegating a VIEW for Ethos-U55, we were overly pessimistic about whether we can delegate the TRANSPOSE that is needed for the NHWC -> NCHW or NCHW -> NHWC permutation. As a result, some RESHAPEs were left over on the CPU when they could actually have been run on the NPU. Signed-off-by: George Gekov <george.gekov@arm.com> Change-Id: I34cc3b38cf0dbb0ceee32ac5d0044805c4e1f085
1 parent 87e65ac commit 1cc2d53

2 files changed

Lines changed: 45 additions & 25 deletions

File tree

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
1111
"""
1212
import typing
13+
from itertools import combinations
1314

1415
import torch
1516
import torch.fx as fx
@@ -281,20 +282,37 @@ def __init__(self, reporter: WhyNoPartitionReporter):
281282

282283
_MAX_AXIS_PRODUCT = 65536
283284

284-
def axes_product(self, shape: shape_t) -> int:
285-
"""Return the product of all axes in ``shape``.
286-
285+
def _max_product_axis(self, shape: shape_t):
286+
"""
287287
Args:
288288
shape (shape_t): Shape.
289289
290290
Returns:
291-
int: Product of the axis sizes.
292-
291+
True if the TRANSPOSE can be run on the Ethos-U55
292+
False if the TRANSPOSE cannot be run on the Ethos-U55
293+
294+
For a tensor of rank N, the product of any combination of
295+
N - 2 axes must not exceed 65536. E.g. for a rank 4 tensor,
296+
N*H, N*W, N*C, H*W, H*C, W*C must all be at most 65536 to
297+
be able to run the TRANSPOSE on Ethos-U55.
298+
The full TRANSPOSE requirements for the Ethos-U55 are listed in
299+
https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/SUPPORTED_OPS.md
293300
"""
294-
product = 1
295-
for axes in shape:
296-
product *= axes
297-
return product
301+
rank = len(shape)
302+
if rank < 3:
303+
product = 1
304+
for idx in shape:
305+
product *= idx
306+
return product <= self._MAX_AXIS_PRODUCT
307+
308+
else:
309+
for axes in combinations(range(rank), rank - 2):
310+
product = 1
311+
for idx in axes:
312+
product *= shape[idx]
313+
if product > self._MAX_AXIS_PRODUCT:
314+
return False
315+
return True
298316

299317
def _check_rank_constraints(
300318
self,
@@ -322,11 +340,11 @@ def _check_rank_constraints(
322340
output_rank = len(output_shape)
323341

324342
if input_rank > 4:
325-
if self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT:
343+
if not (self._max_product_axis(input_shape)):
326344
self.reporter.report_reject(
327345
node,
328346
f"Input may require transpose operator. No support for {input_shape=}, "
329-
f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}",
347+
f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}",
330348
)
331349
return False
332350
if dtype == torch.int32:
@@ -337,12 +355,12 @@ def _check_rank_constraints(
337355
return False
338356

339357
if output_rank > 4:
340-
if self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT:
358+
if not (self._max_product_axis(output_shape)):
341359
shape = output_shape
342360
self.reporter.report_reject(
343361
node,
344362
f"Operator may require transpose operator. No support for {shape=}, "
345-
f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}",
363+
f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}",
346364
)
347365
return False
348366
if dtype == torch.int32:
@@ -450,24 +468,22 @@ def _check_transpose_constraints(
450468
)
451469
return False
452470

453-
if (
454-
needs_input_transpose
455-
and self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT
456-
):
471+
# For TRANSPOSE originating from a VIEW, we know we will only do
472+
# NHWC -> NCHW or NCHW -> NHWC permutations, hence we only need to validate
473+
# these two TRANSPOSEs. For the general case of any permutation on TRANSPOSE,
474+
# we reason via the checks in EthosU55TransposeCheck
475+
if needs_input_transpose and not (self._max_product_axis(input_shape)):
457476
self.reporter.report_reject(
458477
node,
459478
f"Operator requires transpose operator. No support for {input_shape=}, "
460-
f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}",
479+
f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}",
461480
)
462481
return False
463-
if (
464-
needs_output_transpose
465-
and self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT
466-
):
482+
if needs_output_transpose and not (self._max_product_axis(output_shape)):
467483
self.reporter.report_reject(
468484
node,
469485
f"Operator requires transpose operator. No support for {output_shape=}, "
470-
f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}",
486+
f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}",
471487
)
472488
return False
473489

backends/arm/test/ops/test_view.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class View(torch.nn.Module):
5151
"rand_5d_5d": lambda: (torch.rand(1, 1, 4, 5, 6), (1, 1, 4, -1, 6)),
5252
"rand_5d_3d": lambda: (torch.rand(1, 1, 4, 5, 6), (2, 3, -1)),
5353
"rand_3d_5d": lambda: (torch.rand(4, 5, 6), (1, 1, 2, -1, 3)),
54+
"rank4_rank3_large": lambda: (torch.rand(1, 256, 6, 48), (6, 48, 256)),
55+
"rank5_rank4_large": lambda: (torch.rand(1, 256, 2, 3, 48), (1, 256, 6, 48)),
5456
}
5557

5658
needs_transpose_tests_fp16 = {
@@ -65,8 +67,7 @@ class View(torch.nn.Module):
6567
}
6668

6769
rank_product_too_large = {
68-
"rand_4d_large": lambda: (torch.rand(1, 49, 16, 128), (1, 16, 49, 128)),
69-
"rand_5d_large": lambda: (torch.rand(2, 25, 16, 8, 64), (2, 16, 25, 8, 64)),
70+
"rand_5d_large": lambda: (torch.rand(2, 256, 512, 8, 64), (2, 512, 256, 8, 64)),
7071
}
7172

7273
def __init__(self, new_shape):
@@ -116,6 +117,9 @@ def test_view_u55_INT(test_data: Tuple):
116117
aten_op,
117118
exir_ops=[],
118119
)
120+
pipeline.change_args(
121+
"check_not.exir", ["executorch_exir_dialects_edge__ops_aten_view_copy_default"]
122+
)
119123
pipeline.run()
120124

121125

0 commit comments

Comments
 (0)