@@ -336,7 +336,7 @@ def _negative_checks(
336336 checks : list [OperatorSupportBase ] = [RankCheck (reporter , MAX_RANK )]
337337
338338 if not tosa_spec .support_extension ("int64" ):
339- checks .append (CheckInt64InputsAndOutputs (exported_program , reporter ))
339+ checks .append (CheckInt64InputsAndOutputs (exported_program , reporter , tosa_spec ))
340340
341341 checks .extend (_wrapped_additional_checks (additional_checks , reporter ))
342342
@@ -683,7 +683,10 @@ class CheckInt64InputsAndOutputs(OperatorSupportBase):
683683 """
684684
685685 def __init__ (
686- self , exported_program : ExportedProgram , reporter : WhyNoPartitionReporter
686+ self ,
687+ exported_program : ExportedProgram ,
688+ reporter : WhyNoPartitionReporter ,
689+ tosa_spec : TosaSpecification ,
687690 ):
688691 """Initialize the check with program context and reporter."""
689692 self .input_names = [
@@ -692,6 +695,7 @@ def __init__(
692695 if spec .kind == InputKind .USER_INPUT
693696 ]
694697 self .reporter = reporter
698+ self .tosa_spec = tosa_spec
695699 self .int32_min = torch .iinfo (torch .int32 ).min
696700 self .int32_max = torch .iinfo (torch .int32 ).max
697701 super ().__init__ ()
@@ -704,6 +708,104 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
704708 min_val , max_val = int (torch .min (data )), int (torch .max (data ))
705709 return min_val >= self .int32_min and max_val <= self .int32_max
706710
def has_rejected_int64_output(
    self, node: torch.fx.Node, tensor_list: Sequence[typing.Any]
) -> bool:
    """Tell whether ``node`` has to be rejected because of its outputs.

    Argmax nodes (aten or edge dialect) are judged only by
    ``_is_tosa_argmax_supported``; the dtypes in ``tensor_list`` are not
    looked at for them. Every other node is rejected as soon as one
    ``FakeTensor`` entry of ``tensor_list`` has dtype ``torch.int64``.

    Args:
        node: The node under inspection.
        tensor_list: Output values of ``node``; entries that are not
            ``FakeTensor`` instances are ignored.

    Returns:
        True when the node should be treated as having an unsupported
        int64 output, False otherwise.
    """
    argmax_targets = (
        torch.ops.aten.argmax.default,
        exir_ops.edge.aten.argmax.default,
    )
    if node.target not in argmax_targets:
        # Generic path: a single int64 FakeTensor output is enough.
        for candidate in tensor_list:
            if isinstance(candidate, FakeTensor) and candidate.dtype == torch.int64:
                return True
        return False

    # Argmax path: the verdict is the inverse of the argmax support check.
    return not self._is_tosa_argmax_supported(node)
724+
725+ def _is_tosa_argmax_dtype_supported (
726+ self , node : torch .fx .Node , input_dtype : torch .dtype
727+ ) -> bool :
728+ if input_dtype == torch .int8 :
729+ if not self .tosa_spec .support_integer ():
730+ self .reporter .report_reject (
731+ node , "TOSA ARGMAX requires PRO-INT for int8 input."
732+ )
733+ return False
734+ elif input_dtype == torch .int16 :
735+ if not (
736+ self .tosa_spec .support_integer ()
737+ and self .tosa_spec .support_extension ("int16" )
738+ ):
739+ self .reporter .report_reject (
740+ node , "TOSA ARGMAX requires EXT-INT16 for int16 input."
741+ )
742+ return False
743+ elif input_dtype in (torch .float16 , torch .float32 ):
744+ if not self .tosa_spec .support_float ():
745+ self .reporter .report_reject (
746+ node , f"TOSA ARGMAX requires PRO-FP for { input_dtype } input."
747+ )
748+ return False
749+ elif input_dtype == torch .bfloat16 :
750+ if not (
751+ self .tosa_spec .support_float ()
752+ and self .tosa_spec .support_extension ("bf16" )
753+ ):
754+ self .reporter .report_reject (
755+ node , "TOSA ARGMAX requires EXT-BF16 for bfloat16 input."
756+ )
757+ return False
758+ else :
759+ self .reporter .report_reject (
760+ node , f"TOSA ARGMAX does not support { input_dtype } input."
761+ )
762+ return False
763+ return True
764+
765+ def _is_tosa_argmax_supported (self , node : torch .fx .Node ) -> bool :
766+ dim = node .kwargs .get ("dim" , node .args [1 ] if len (node .args ) > 1 else None )
767+ if dim is None :
768+ self .reporter .report_reject (
769+ node , "TOSA ARGMAX requires an explicit reduction dimension."
770+ )
771+ return False
772+ if not isinstance (dim , int ):
773+ self .reporter .report_reject (
774+ node , "TOSA ARGMAX requires a statically known reduction dimension."
775+ )
776+ return False
777+
778+ input_node = typing .cast (torch .fx .Node , node .args [0 ])
779+ input_tensor = get_first_fake_tensor (input_node )
780+ if not self ._is_tosa_argmax_dtype_supported (node , input_tensor .dtype ):
781+ return False
782+
783+ input_rank = len (input_tensor .shape )
784+ if input_rank == 0 :
785+ self .reporter .report_reject (
786+ node , "TOSA ARGMAX requires an input with rank at least 1."
787+ )
788+ return False
789+
790+ axis = dim + input_rank if dim < 0 else dim
791+ if axis < 0 or axis >= input_rank :
792+ self .reporter .report_reject (
793+ node ,
794+ f"TOSA ARGMAX axis must be in [0, { input_rank - 1 } ] but got { dim } ." ,
795+ )
796+ return False
797+
798+ keepdim = node .kwargs .get (
799+ "keepdim" , node .args [2 ] if len (node .args ) > 2 else False
800+ )
801+ if keepdim :
802+ self .reporter .report_reject (
803+ node , "TOSA ARGMAX does not support keepdim=True."
804+ )
805+ return False
806+
807+ return True
808+
707809 def _check_int64_input_nodes (self , node : torch .fx .Node ) -> bool :
708810 """Check if all int64 input nodes are constant and will be
709811 partitioned.
@@ -747,11 +849,7 @@ def is_node_supported(
747849 vals = node .meta ["val" ]
748850 tensor_list = vals if isinstance (vals , (list , tuple )) else [vals ]
749851
750- any_int64 = any (
751- tensor .dtype == torch .int64
752- for tensor in tensor_list
753- if isinstance (tensor , FakeTensor )
754- )
852+ any_int64 = self .has_rejected_int64_output (node , tensor_list )
755853 # Don't partition nodes with int64 output...
756854 if any_int64 :
757855 # ... Except for constant ops that are directly cast to something non-int64.
0 commit comments