Skip to content

Commit 16a80de

Browse files
authored
Arm backend: Extend fp64 partition check to more dtypes (#16703)
Check both inputs and outputs. Specifically add check for bf16 operators. Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 8f9e0b2 commit 16a80de

1 file changed

Lines changed: 41 additions & 12 deletions

File tree

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -282,7 +282,6 @@ def tosa_support_factory(
282282
# Negative checks: Remove nodes from partitioning
283283
negative_checks: list[OperatorSupportBase] = [
284284
CheckInt64InputsAndOutputs(exported_program, reporter),
285-
CheckFloat64Inputs(exported_program, reporter),
286285
RankCheck(reporter, max_rank=MAX_RANK),
287286
*[
288287
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
@@ -293,6 +292,15 @@ def tosa_support_factory(
293292
if not tosa_spec.support_float():
294293
negative_checks.append(CheckArmQuantized(reporter))
295294
negative_checks.append(CheckProperQuantization(reporter))
295+
296+
disallowed_dtypes = [torch.float64]
297+
if not tosa_spec.support_extension("bf16"):
298+
disallowed_dtypes.append(torch.bfloat16)
299+
negative_checks.append(
300+
CheckDtypeInputsAndOutputs(
301+
exported_program, reporter, disallowed_dtypes, tosa_spec
302+
)
303+
)
296304
if tosa_spec.is_U55_subset:
297305
negative_checks.append(EthosU55NotSupported(reporter))
298306
negative_checks.append(EthosU55DtypeSupport(reporter))
@@ -657,24 +665,26 @@ def is_node_supported(
657665
return True
658666

659667

660-
class CheckFloat64Inputs(OperatorSupportBase):
661-
"""Reject nodes with float64 inputs.
662-
663-
Useful as a negative check for specs that do not allow float64.
664-
665-
"""
668+
class CheckDtypeInputsAndOutputs(OperatorSupportBase):
669+
"""Reject nodes with at least one disallowed dtype on inputs or outputs."""
666670

667671
def __init__(
668-
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
672+
self,
673+
exported_program: ExportedProgram,
674+
reporter: WhyNoPartitionReporter,
675+
disallowed_dtypes: list[torch.dtype],
676+
tosa_spec: TosaSpecification,
669677
):
670678
"""Initialize the check with program context and reporter."""
671679
self.reporter = reporter
680+
self.disallowed_dtypes = disallowed_dtypes
681+
self.tosa_spec = tosa_spec
672682
super().__init__()
673683

674684
def is_node_supported(
675685
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
676686
) -> bool:
677-
"""Return True if no float64 inputs are present."""
687+
"""Return True if no disallowed dtypes are present on inputs or outputs."""
678688
if is_submodule_node(node):
679689
return True
680690
for input_node in (
@@ -683,10 +693,29 @@ def is_node_supported(
683693
if input_node.op != "get_attr"
684694
):
685695
tensor = get_first_fake_tensor(input_node)
686-
if tensor.dtype == torch.float64:
696+
if tensor.dtype in self.disallowed_dtypes:
697+
self.reporter.report_reject(
698+
node,
699+
f"Had {tensor.dtype} input {input_node.name} that is not supported by {self.tosa_spec}.",
700+
)
701+
return False
702+
703+
meta_val = node.meta["val"]
704+
if isinstance(
705+
meta_val, (Sequence, torch.fx.immutable_collections.immutable_list)
706+
):
707+
outputs = meta_val
708+
else:
709+
outputs = (meta_val,)
710+
711+
for output in outputs:
712+
if (
713+
isinstance(output, FakeTensor)
714+
and output.dtype in self.disallowed_dtypes
715+
):
687716
self.reporter.report_reject(
688717
node,
689-
f"Had float64 input {input_node.name} that couldn't be handled.",
718+
f"Had {output.dtype} output that is not supported by {self.tosa_spec}.",
690719
)
691720
return False
692721
return True

0 commit comments

Comments
 (0)