-
Notifications
You must be signed in to change notification settings - Fork 961
Arm backend: Fix rejection criteria of TRANSPOSE from VIEW #19044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
|
|
||
| """ | ||
| import typing | ||
| from itertools import combinations | ||
|
|
||
| import torch | ||
| import torch.fx as fx | ||
|
|
@@ -281,20 +282,37 @@ def __init__(self, reporter: WhyNoPartitionReporter): | |
|
|
||
| _MAX_AXIS_PRODUCT = 65536 | ||
|
|
||
| def axes_product(self, shape: shape_t) -> int: | ||
| """Return the product of all axes in ``shape``. | ||
|
|
||
| def _max_product_axis(self, shape: shape_t): | ||
| """ | ||
| Args: | ||
| shape (shape_t): Shape. | ||
|
|
||
| Returns: | ||
| int: Product of the axis sizes. | ||
|
|
||
| True if the TRANSPOSE can be run on the Ethos-U55 | ||
| False if the TRANSPOSE cannot be run on the Ethos-U55 | ||
|
|
||
| For a tensor of rank N, the product of any combination of | ||
| N - 2 axis needs to be less than 65536. E.g. for rank 4 tensor, | ||
| N*H, N*W, N*C, H*W, H*C, W*C should all be lower than 65536 to | ||
| be able to run the TRANSPOSE on Ethos-U55. | ||
| The full TRANSPOSE requirements for the Ethos-U55 are listed in | ||
| https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/SUPPORTED_OPS.md | ||
| """ | ||
| product = 1 | ||
| for axes in shape: | ||
| product *= axes | ||
| return product | ||
| rank = len(shape) | ||
| if rank < 3: | ||
| product = 1 | ||
| for idx in shape: | ||
| product *= idx | ||
| return product <= self._MAX_AXIS_PRODUCT | ||
|
|
||
| else: | ||
| for axes in combinations(range(rank), rank - 2): | ||
| product = 1 | ||
| for idx in axes: | ||
| product *= shape[idx] | ||
| if product > self._MAX_AXIS_PRODUCT: | ||
| return False | ||
| return True | ||
|
|
||
| def _check_rank_constraints( | ||
| self, | ||
|
|
@@ -322,11 +340,11 @@ def _check_rank_constraints( | |
| output_rank = len(output_shape) | ||
|
|
||
| if input_rank > 4: | ||
| if self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT: | ||
| if not (self._max_product_axis(input_shape)): | ||
| self.reporter.report_reject( | ||
| node, | ||
| f"Input may require transpose operator. No support for {input_shape=}, " | ||
| f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| ) | ||
| return False | ||
| if dtype == torch.int32: | ||
|
|
@@ -337,12 +355,12 @@ def _check_rank_constraints( | |
| return False | ||
|
|
||
| if output_rank > 4: | ||
| if self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT: | ||
| if not (self._max_product_axis(output_shape)): | ||
| shape = output_shape | ||
| self.reporter.report_reject( | ||
| node, | ||
| f"Operator may require transpose operator. No support for {shape=}, " | ||
| f"{dtype=}. Product of axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| ) | ||
| return False | ||
| if dtype == torch.int32: | ||
|
|
@@ -450,24 +468,22 @@ def _check_transpose_constraints( | |
| ) | ||
| return False | ||
|
|
||
| if ( | ||
| needs_input_transpose | ||
| and self.axes_product(input_shape) > self._MAX_AXIS_PRODUCT | ||
| ): | ||
| # For TRANSPOSE originating from a VIEW, we know we will only do | ||
| # NHWC -> NCHW or NCHW -> NHWC permutations, hence we only need to validate | ||
| # these two TRANSPOSEs. For the general case of any permutation on TRANSPOSE, | ||
| # we reason via the checks in EthosU55TransposeCheck | ||
| if needs_input_transpose and not (self._max_product_axis(input_shape)): | ||
| self.reporter.report_reject( | ||
|
Comment on lines
+471
to
476
|
||
| node, | ||
| f"Operator requires transpose operator. No support for {input_shape=}, " | ||
| f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", | ||
| f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| ) | ||
| return False | ||
| if ( | ||
| needs_output_transpose | ||
| and self.axes_product(output_shape) > self._MAX_AXIS_PRODUCT | ||
| ): | ||
| if needs_output_transpose and not (self._max_product_axis(output_shape)): | ||
| self.reporter.report_reject( | ||
| node, | ||
| f"Operator requires transpose operator. No support for {output_shape=}, " | ||
| f"{dtype=}. Product of axes must be <{self._MAX_AXIS_PRODUCT}", | ||
| f"{dtype=}. Product of any rank - 2 axes must be <={self._MAX_AXIS_PRODUCT}", | ||
| ) | ||
| return False | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -51,6 +51,8 @@ class View(torch.nn.Module): | |||||
| "rand_5d_5d": lambda: (torch.rand(1, 1, 4, 5, 6), (1, 1, 4, -1, 6)), | ||||||
| "rand_5d_3d": lambda: (torch.rand(1, 1, 4, 5, 6), (2, 3, -1)), | ||||||
| "rand_3d_5d": lambda: (torch.rand(4, 5, 6), (1, 1, 2, -1, 3)), | ||||||
| "rank4_rank3_large": lambda: (torch.rand(1, 256, 6, 48), (6, 48, 256)), | ||||||
| "rank5_rank4_large": lambda: (torch.rand(1, 256, 2, 3, 48), (1, 256, 6, 48)), | ||||||
| } | ||||||
|
|
||||||
| needs_transpose_tests_fp16 = { | ||||||
|
|
@@ -65,8 +67,7 @@ class View(torch.nn.Module): | |||||
| } | ||||||
|
|
||||||
| rank_product_too_large = { | ||||||
| "rand_4d_large": lambda: (torch.rand(1, 49, 16, 128), (1, 16, 49, 128)), | ||||||
| "rand_5d_large": lambda: (torch.rand(2, 25, 16, 8, 64), (2, 16, 25, 8, 64)), | ||||||
| "rand_5d_large": lambda: (torch.rand(2, 256, 512, 8, 64), (2, 512, 256, 8, 64)), | ||||||
|
||||||
| "rand_5d_large": lambda: (torch.rand(2, 256, 512, 8, 64), (2, 512, 256, 8, 64)), | |
| "rand_5d_large": lambda: (torch.rand(1, 257, 256, 1, 2), (1, 256, 257, 1, 2)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is expected. I am not testing that the model can run on the NPU; I am testing whether the TRANSPOSE is correctly rejected. For the rejection to trigger, the product of some combination of rank − 2 axes must exceed 2**16, which requires really big tensors in the test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_max_product_axis’s docstring is misleading/inaccurate: it says it “Return[s] shape padded to rank4”, but the function actually returns a boolean, and it doesn’t pad the shape. Consider renaming the helper to reflect that it validates axis-product constraints (and add a `-> bool` return type), and update the docstring to describe the boolean semantics and the rule being enforced.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ACK will fix.