Skip to content

Commit a4737c6

Browse files
martinlsmMartin LindströmCopilot
authored
Arm backend: Remove code duplication in Ethos-U pipelines (#16579)
- EthosU55PipelineINT and EthosU85PipelineINT contained copied code in their respective `__init__` functions. Introduce a common super class EthosUPipelineINTBase that runs the common code upon initialization. - Some pipeline class names have the suffix "Maker". Remove this suffix to make naming of the classes consistent. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai --------- Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 183aff3 commit a4737c6

File tree

1 file changed

+94
-100
lines changed

1 file changed

+94
-100
lines changed

backends/arm/test/tester/test_pipeline.py

Lines changed: 94 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,6 +22,7 @@
2222

2323
import torch
2424
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
25+
from executorch.backends.arm.ethosu import EthosUCompileSpec
2526

2627
from executorch.backends.arm.quantizer import (
2728
EthosUQuantizer,
@@ -88,9 +89,9 @@ def update(self, *args, **kwargs):
8889
raise RuntimeError(f"{self.id} args updated after being called.")
8990

9091

91-
class BasePipelineMaker(Generic[T]):
92+
class BasePipeline(Generic[T]):
9293
"""
93-
The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it
94+
The BasePipeline defines a list of stages to be applied to a torch.nn.module for lowering it
9495
in the Arm backend. To be inherited and adjusted for particular targets. Importantly, the
9596
pipeline list can be modified before running the pipeline to support various pipeline extensions
9697
and debugging usecases.
@@ -317,7 +318,7 @@ def run(self):
317318
raise e
318319

319320

320-
class TOSAPipelineMaker(BasePipelineMaker, Generic[T]):
321+
class TOSAPipeline(BasePipeline, Generic[T]):
321322
@staticmethod
322323
def is_tosa_ref_model_available():
323324
"""Checks if the TOSA reference model is available."""
@@ -343,7 +344,7 @@ def run(self):
343344
super().run()
344345

345346

346-
class TosaPipelineINT(TOSAPipelineMaker, Generic[T]):
347+
class TosaPipelineINT(TOSAPipeline, Generic[T]):
347348
"""
348349
Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model.
349350
@@ -475,7 +476,7 @@ def __init__(
475476
)
476477

477478

478-
class TosaPipelineFP(TOSAPipelineMaker, Generic[T]):
479+
class TosaPipelineFP(TOSAPipeline, Generic[T]):
479480
"""
480481
Lowers a graph to FP TOSA spec and tests it with the TOSA reference model.
481482
@@ -558,44 +559,35 @@ def __init__(
558559
)
559560

560561

561-
class EthosU55PipelineINT(BasePipelineMaker, Generic[T]):
562-
"""
563-
Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true.
564-
565-
Attributes:
566-
module: The module which the pipeline is applied to.
567-
test_data: Data used for quantizing and testing the module.
568-
aten_ops: Aten dialect ops expected to be found in the graph after export.
569-
570-
exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
571-
if not using use_edge_to_transform_and_lower.
572-
run_on_fvp: Set to true to test the pte fileon a fvp simulator.
573-
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
574-
custom_path : Path to dump intermediate artifacts such as tosa and pte to.
575-
"""
562+
class EthosUPipelineINTBase(BasePipeline, Generic[T]):
563+
"""Base class that encapsulates shared Ethos-U INT pipeline setup."""
576564

577565
def __init__(
578566
self,
567+
compile_spec: EthosUCompileSpec,
579568
module: torch.nn.Module,
580569
test_data: T,
581570
aten_ops: str | List[str],
582-
exir_ops: Optional[str | List[str]] = None,
571+
exir_ops: str | Sequence[str] | None,
583572
run_on_fvp: bool = True,
584573
symmetric_io_quantization: bool = False,
585574
per_channel_quantization: bool = True,
586575
a16w8_quantization: bool = False,
587576
use_to_edge_transform_and_lower: bool = True,
588-
custom_path: str | None = None,
589-
tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
590577
atol: float = 1e-03,
591578
rtol: float = 1e-03,
592579
qtol: int = 1,
593580
epsilon: float = 2**-12,
594581
):
595-
compile_spec = common.get_u55_compile_spec(
596-
custom_path=custom_path,
597-
tosa_debug_mode=tosa_debug_mode,
582+
super().__init__(
583+
module,
584+
test_data,
585+
aten_ops,
586+
compile_spec,
587+
exir_ops,
588+
use_to_edge_transform_and_lower,
598589
)
590+
599591
quantizer = EthosUQuantizer(compile_spec)
600592
# choose int8 or int16 activation quantization
601593
if a16w8_quantization:
@@ -610,15 +602,6 @@ def __init__(
610602
quantizer.set_io(quantization_config)
611603
quant_stage = Quantize(quantizer, quantization_config)
612604

613-
super().__init__(
614-
module,
615-
test_data,
616-
aten_ops,
617-
compile_spec,
618-
exir_ops,
619-
use_to_edge_transform_and_lower,
620-
)
621-
622605
self.add_stage(self.tester.quantize, quant_stage, pos=0)
623606

624607
remove_quant_nodes_stage = (
@@ -659,19 +642,19 @@ def __init__(
659642
)
660643

661644

662-
class EthosU85PipelineINT(BasePipelineMaker, Generic[T]):
645+
class EthosU55PipelineINT(EthosUPipelineINTBase, Generic[T]):
663646
"""
664-
Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true.
647+
Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true.
665648
666649
Attributes:
667650
module: The module which the pipeline is applied to.
668651
test_data: Data used for quantizing and testing the module.
669652
aten_ops: Aten dialect ops expected to be found in the graph after export.
670653
671-
exir_ops: Exir dialect ops expected to be found in the graph after to_edge if not using
672-
use_edge_to_transform_and_lower.
673-
run_on_fvp: Set to true to test the pte fileon a fvp simulator.
674-
use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
654+
exir_ops: Exir dialect ops expected to be found in the graph after to_edge
655+
if not using use_to_edge_transform_and_lower.
656+
run_on_fvp: Set to true to test the pte file on an FVP simulator.
657+
use_to_edge_transform_and_lower: Selects between two possible ways of lowering the module.
675658
custom_path : Path to dump intermediate artifacts such as tosa and pte to.
676659
"""
677660

@@ -680,7 +663,7 @@ def __init__(
680663
module: torch.nn.Module,
681664
test_data: T,
682665
aten_ops: str | List[str],
683-
exir_ops: str | List[str] | None = None,
666+
exir_ops: str | Sequence[str] | None = None,
684667
run_on_fvp: bool = True,
685668
symmetric_io_quantization: bool = False,
686669
per_channel_quantization: bool = True,
@@ -693,74 +676,85 @@ def __init__(
693676
qtol: int = 1,
694677
epsilon: float = 2**-12,
695678
):
696-
compile_spec = common.get_u85_compile_spec(
679+
compile_spec = common.get_u55_compile_spec(
697680
custom_path=custom_path,
698681
tosa_debug_mode=tosa_debug_mode,
699682
)
700-
quantizer = EthosUQuantizer(compile_spec)
701-
# choose int8 or int16 activation quantization
702-
if a16w8_quantization:
703-
quantization_config = get_symmetric_a16w8_quantization_config(
704-
is_per_channel=per_channel_quantization, epsilon=epsilon
705-
)
706-
else:
707-
quantization_config = get_symmetric_quantization_config(
708-
is_per_channel=per_channel_quantization
709-
)
710-
if symmetric_io_quantization:
711-
quantizer.set_io(quantization_config)
712-
quant_stage = Quantize(quantizer, quantization_config)
713-
714683
super().__init__(
684+
compile_spec,
715685
module,
716686
test_data,
717687
aten_ops,
718-
compile_spec,
719688
exir_ops,
720-
use_to_edge_transform_and_lower,
689+
run_on_fvp=run_on_fvp,
690+
symmetric_io_quantization=symmetric_io_quantization,
691+
per_channel_quantization=per_channel_quantization,
692+
a16w8_quantization=a16w8_quantization,
693+
use_to_edge_transform_and_lower=use_to_edge_transform_and_lower,
694+
atol=atol,
695+
rtol=rtol,
696+
qtol=qtol,
697+
epsilon=epsilon,
721698
)
722699

723-
self.add_stage(self.tester.quantize, quant_stage, pos=0)
724700

725-
remove_quant_nodes_stage = (
726-
"to_edge_transform_and_lower"
727-
if use_to_edge_transform_and_lower
728-
else "partition"
729-
)
701+
class EthosU85PipelineINT(EthosUPipelineINTBase, Generic[T]):
702+
"""
703+
Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true.
730704
731-
if _has_quantizable_inputs(test_data):
732-
# only add stages if we have quantizable input
733-
self.add_stage_after(
734-
"quantize",
735-
self.tester.check,
736-
[
737-
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
738-
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
739-
],
740-
suffix="quant_nodes",
741-
)
742-
self.add_stage_after(
743-
remove_quant_nodes_stage,
744-
self.tester.check_not,
745-
[
746-
"torch.ops.quantized_decomposed.dequantize_per_tensor.default",
747-
"torch.ops.quantized_decomposed.quantize_per_tensor.default",
748-
],
749-
suffix="quant_nodes",
750-
)
705+
Attributes:
706+
module: The module which the pipeline is applied to.
707+
test_data: Data used for quantizing and testing the module.
708+
aten_ops: Aten dialect ops expected to be found in the graph after export.
751709
752-
if run_on_fvp:
753-
self.add_stage(self.tester.serialize)
754-
self.add_stage(
755-
self.tester.run_method_and_compare_outputs,
756-
atol=atol,
757-
rtol=rtol,
758-
qtol=qtol,
759-
inputs=self.test_data,
760-
)
710+
exir_ops: Exir dialect ops expected to be found in the graph after to_edge if not using
711+
use_to_edge_transform_and_lower.
712+
run_on_fvp: Set to true to test the pte file on an FVP simulator.
713+
use_to_edge_transform_and_lower: Selects between two possible ways of lowering the module.
714+
custom_path : Path to dump intermediate artifacts such as tosa and pte to.
715+
"""
716+
717+
def __init__(
718+
self,
719+
module: torch.nn.Module,
720+
test_data: T,
721+
aten_ops: str | List[str],
722+
exir_ops: str | Sequence[str] | None = None,
723+
run_on_fvp: bool = True,
724+
symmetric_io_quantization: bool = False,
725+
per_channel_quantization: bool = True,
726+
a16w8_quantization: bool = False,
727+
use_to_edge_transform_and_lower: bool = True,
728+
custom_path: str | None = None,
729+
tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None,
730+
atol: float = 1e-03,
731+
rtol: float = 1e-03,
732+
qtol: int = 1,
733+
epsilon: float = 2**-12,
734+
):
735+
compile_spec = common.get_u85_compile_spec(
736+
custom_path=custom_path,
737+
tosa_debug_mode=tosa_debug_mode,
738+
)
739+
super().__init__(
740+
compile_spec,
741+
module,
742+
test_data,
743+
aten_ops,
744+
exir_ops,
745+
run_on_fvp=run_on_fvp,
746+
symmetric_io_quantization=symmetric_io_quantization,
747+
per_channel_quantization=per_channel_quantization,
748+
a16w8_quantization=a16w8_quantization,
749+
use_to_edge_transform_and_lower=use_to_edge_transform_and_lower,
750+
atol=atol,
751+
rtol=rtol,
752+
qtol=qtol,
753+
epsilon=epsilon,
754+
)
761755

762756

763-
class PassPipeline(TOSAPipelineMaker, Generic[T]):
757+
class PassPipeline(TOSAPipeline, Generic[T]):
764758
"""
765759
Runs single passes directly on an edge_program and checks operators before/after.
766760
@@ -858,7 +852,7 @@ def run(self):
858852
super().run()
859853

860854

861-
class TransformAnnotationPassPipeline(TOSAPipelineMaker, Generic[T]):
855+
class TransformAnnotationPassPipeline(TOSAPipeline, Generic[T]):
862856
"""
863857
Runs transform_for_annotation_pipeline passes directly on an exported program and checks output.
864858
@@ -914,7 +908,7 @@ def __init__(
914908
)
915909

916910

917-
class QuantizationPipeline(TOSAPipelineMaker, Generic[T]):
911+
class QuantizationPipeline(TOSAPipeline, Generic[T]):
918912
"""
919913
Runs quantization and checks that appropriate nodes are annotated with an expected
920914
quantization-spec.
@@ -971,7 +965,7 @@ def __init__(
971965
)
972966

973967

974-
class OpNotSupportedPipeline(TOSAPipelineMaker, Generic[T]):
968+
class OpNotSupportedPipeline(TOSAPipeline, Generic[T]):
975969
"""
976970
Runs the partitioner on a module and checks that ops are not delegated to test
977971
SupportedTOSAOperatorChecks.
@@ -1040,7 +1034,7 @@ def __init__(
10401034
self.pop_stage("to_executorch")
10411035

10421036

1043-
class VgfPipeline(BasePipelineMaker, Generic[T]):
1037+
class VgfPipeline(BasePipeline, Generic[T]):
10441038
"""
10451039
Lowers a graph based on TOSA spec (with or without quantization) and converts TOSA to VGF.
10461040

0 commit comments

Comments
 (0)