1- # Copyright 2024-2025 Arm Limited and/or its affiliates.
1+ # Copyright 2024-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
@@ -282,7 +282,6 @@ def tosa_support_factory(
282282 # Negative checks: Remove nodes from partitioning
283283 negative_checks : list [OperatorSupportBase ] = [
284284 CheckInt64InputsAndOutputs (exported_program , reporter ),
285- CheckFloat64Inputs (exported_program , reporter ),
286285 RankCheck (reporter , max_rank = MAX_RANK ),
287286 * [
288287 reporter .wrap_check (check , f"Rejected by { check .__class__ .__name__ } " )
@@ -293,6 +292,15 @@ def tosa_support_factory(
293292 if not tosa_spec .support_float ():
294293 negative_checks .append (CheckArmQuantized (reporter ))
295294 negative_checks .append (CheckProperQuantization (reporter ))
295+
296+ disallowed_dtypes = [torch .float64 ]
297+ if not tosa_spec .support_extension ("bf16" ):
298+ disallowed_dtypes .append (torch .bfloat16 )
299+ negative_checks .append (
300+ CheckDtypeInputsAndOutputs (
301+ exported_program , reporter , disallowed_dtypes , tosa_spec
302+ )
303+ )
296304 if tosa_spec .is_U55_subset :
297305 negative_checks .append (EthosU55NotSupported (reporter ))
298306 negative_checks .append (EthosU55DtypeSupport (reporter ))
@@ -657,24 +665,26 @@ def is_node_supported(
657665 return True
658666
659667
660- class CheckFloat64Inputs (OperatorSupportBase ):
661- """Reject nodes with float64 inputs.
662-
663- Useful as a negative check for specs that do not allow float64.
664-
665- """
668+ class CheckDtypeInputsAndOutputs (OperatorSupportBase ):
669+ """Reject nodes with at least one disallowed dtype on inputs or outputs."""
666670
667671 def __init__ (
668- self , exported_program : ExportedProgram , reporter : WhyNoPartitionReporter
672+ self ,
673+ exported_program : ExportedProgram ,
674+ reporter : WhyNoPartitionReporter ,
675+ disallowed_dtypes : list [torch .dtype ],
676+ tosa_spec : TosaSpecification ,
669677 ):
670678 """Initialize the check with program context and reporter."""
671679 self .reporter = reporter
680+ self .disallowed_dtypes = disallowed_dtypes
681+ self .tosa_spec = tosa_spec
672682 super ().__init__ ()
673683
674684 def is_node_supported (
675685 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
676686 ) -> bool :
677- """Return True if no float64 inputs are present."""
687+ """Return True if no disallowed dtypes are present on inputs or outputs ."""
678688 if is_submodule_node (node ):
679689 return True
680690 for input_node in (
@@ -683,10 +693,29 @@ def is_node_supported(
683693 if input_node .op != "get_attr"
684694 ):
685695 tensor = get_first_fake_tensor (input_node )
686- if tensor .dtype == torch .float64 :
696+ if tensor .dtype in self .disallowed_dtypes :
697+ self .reporter .report_reject (
698+ node ,
699+ f"Had { tensor .dtype } input { input_node .name } that is not supported by { self .tosa_spec } ." ,
700+ )
701+ return False
702+
703+ meta_val = node .meta ["val" ]
704+ if isinstance (
705+ meta_val , (Sequence , torch .fx .immutable_collections .immutable_list )
706+ ):
707+ outputs = meta_val
708+ else :
709+ outputs = (meta_val ,)
710+
711+ for output in outputs :
712+ if (
713+ isinstance (output , FakeTensor )
714+ and output .dtype in self .disallowed_dtypes
715+ ):
687716 self .reporter .report_reject (
688717 node ,
689- f"Had float64 input { input_node . name } that couldn't be handled ." ,
718+ f"Had { output . dtype } output that is not supported by { self . tosa_spec } ." ,
690719 )
691720 return False
692721 return True
0 commit comments