1010
1111"""
1212import typing
13+ from itertools import combinations
1314
1415import torch
1516import torch .fx as fx
@@ -281,20 +282,37 @@ def __init__(self, reporter: WhyNoPartitionReporter):
281282
282283 _MAX_AXIS_PRODUCT = 65536
283284
284- def axes_product (self , shape : shape_t ) -> int :
285- """Return the product of all axes in ``shape``.
286-
285+ def _max_product_axis (self , shape : shape_t ):
286+ """
287287 Args:
288288 shape (shape_t): Shape.
289289
290290 Returns:
291- int: Product of the axis sizes.
292-
291+ True if the TRANSPOSE can be run on the Ethos-U55
292+ False if the TRANSPOSE cannot be run on the Ethos-U55
293+
294+ For a tensor of rank N, the product of any combination of
295+ N - 2 axis needs to be less than 65536. E.g. for rank 4 tensor,
296+ N*H, N*W, N*C, H*W, H*C, W*C should all be lower than 65536 to
297+ be able to run the TRANSPOSE on Ethos-U55.
298+ The full TRANSPOSE requirements for the Ethos-U55 are listed in
299+ https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/blob/main/SUPPORTED_OPS.md
293300 """
294- product = 1
295- for axes in shape :
296- product *= axes
297- return product
301+ rank = len (shape )
302+ if rank < 3 :
303+ product = 1
304+ for idx in shape :
305+ product *= idx
306+ return product <= self ._MAX_AXIS_PRODUCT
307+
308+ else :
309+ for axes in combinations (range (rank ), rank - 2 ):
310+ product = 1
311+ for idx in axes :
312+ product *= shape [idx ]
313+ if product > self ._MAX_AXIS_PRODUCT :
314+ return False
315+ return True
298316
299317 def _check_rank_constraints (
300318 self ,
@@ -322,11 +340,11 @@ def _check_rank_constraints(
322340 output_rank = len (output_shape )
323341
324342 if input_rank > 4 :
325- if self .axes_product (input_shape ) > self . _MAX_AXIS_PRODUCT :
343+ if not ( self ._max_product_axis (input_shape )) :
326344 self .reporter .report_reject (
327345 node ,
328346 f"Input may require transpose operator. No support for { input_shape = } , "
329- f"{ dtype = } . Product of axes must be <={ self ._MAX_AXIS_PRODUCT } " ,
347+ f"{ dtype = } . Product of any rank - 2 axes must be <={ self ._MAX_AXIS_PRODUCT } " ,
330348 )
331349 return False
332350 if dtype == torch .int32 :
@@ -337,12 +355,12 @@ def _check_rank_constraints(
337355 return False
338356
339357 if output_rank > 4 :
340- if self .axes_product (output_shape ) > self . _MAX_AXIS_PRODUCT :
358+ if not ( self ._max_product_axis (output_shape )) :
341359 shape = output_shape
342360 self .reporter .report_reject (
343361 node ,
344362 f"Operator may require transpose operator. No support for { shape = } , "
345- f"{ dtype = } . Product of axes must be <={ self ._MAX_AXIS_PRODUCT } " ,
363+ f"{ dtype = } . Product of any rank - 2 axes must be <={ self ._MAX_AXIS_PRODUCT } " ,
346364 )
347365 return False
348366 if dtype == torch .int32 :
@@ -450,24 +468,22 @@ def _check_transpose_constraints(
450468 )
451469 return False
452470
453- if (
454- needs_input_transpose
455- and self .axes_product (input_shape ) > self ._MAX_AXIS_PRODUCT
456- ):
471+ # For TRANSPOSE originating from a VIEW, we know we will only do
472+ # NHWC -> NCHW or NCHW -> NHWC permutations, hence we only need to validate
473+ # these two TRANSPOSEs. For the general case of any permutation on TRANSPOSE,
474+ # we reason via the checks in EthosU55TransposeCheck
475+ if needs_input_transpose and not (self ._max_product_axis (input_shape )):
457476 self .reporter .report_reject (
458477 node ,
459478 f"Operator requires transpose operator. No support for { input_shape = } , "
460- f"{ dtype = } . Product of axes must be <{ self ._MAX_AXIS_PRODUCT } " ,
479+ f"{ dtype = } . Product of any rank - 2 axes must be <= { self ._MAX_AXIS_PRODUCT } " ,
461480 )
462481 return False
463- if (
464- needs_output_transpose
465- and self .axes_product (output_shape ) > self ._MAX_AXIS_PRODUCT
466- ):
482+ if needs_output_transpose and not (self ._max_product_axis (output_shape )):
467483 self .reporter .report_reject (
468484 node ,
469485 f"Operator requires transpose operator. No support for { output_shape = } , "
470- f"{ dtype = } . Product of axes must be <{ self ._MAX_AXIS_PRODUCT } " ,
486+ f"{ dtype = } . Product of any rank - 2 axes must be <= { self ._MAX_AXIS_PRODUCT } " ,
471487 )
472488 return False
473489
0 commit comments