1- # Copyright 2025 Arm Limited and/or its affiliates.
1+ # Copyright 2025-2026 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
2222
2323import torch
2424from executorch .backends .arm .common .arm_compile_spec import ArmCompileSpec
25+ from executorch .backends .arm .ethosu import EthosUCompileSpec
2526
2627from executorch .backends .arm .quantizer import (
2728 EthosUQuantizer ,
@@ -88,9 +89,9 @@ def update(self, *args, **kwargs):
8889 raise RuntimeError (f"{ self .id } args updated after being called." )
8990
9091
91- class BasePipelineMaker (Generic [T ]):
92+ class BasePipeline (Generic [T ]):
9293 """
93- The BasePiplineMaker defines a list of stages to be applied to a torch.nn.module for lowering it
94+ The BasePipeline defines a list of stages to be applied to a torch.nn.module for lowering it
9495 in the Arm backend. To be inherited and adjusted for particular targets. Importantly, the
9596 pipeline list can be modified before running the pipeline to support various pipeline extensions
9697 and debugging usecases.
@@ -317,7 +318,7 @@ def run(self):
317318 raise e
318319
319320
320- class TOSAPipelineMaker ( BasePipelineMaker , Generic [T ]):
321+ class TOSAPipeline ( BasePipeline , Generic [T ]):
321322 @staticmethod
322323 def is_tosa_ref_model_available ():
323324 """Checks if the TOSA reference model is available."""
@@ -343,7 +344,7 @@ def run(self):
343344 super ().run ()
344345
345346
346- class TosaPipelineINT (TOSAPipelineMaker , Generic [T ]):
347+ class TosaPipelineINT (TOSAPipeline , Generic [T ]):
347348 """
348349 Lowers a graph to INT TOSA spec (with quantization) and tests it with the TOSA reference model.
349350
@@ -475,7 +476,7 @@ def __init__(
475476 )
476477
477478
478- class TosaPipelineFP (TOSAPipelineMaker , Generic [T ]):
479+ class TosaPipelineFP (TOSAPipeline , Generic [T ]):
479480 """
480481 Lowers a graph to FP TOSA spec and tests it with the TOSA reference model.
481482
@@ -558,44 +559,35 @@ def __init__(
558559 )
559560
560561
561- class EthosU55PipelineINT (BasePipelineMaker , Generic [T ]):
562- """
563- Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true.
564-
565- Attributes:
566- module: The module which the pipeline is applied to.
567- test_data: Data used for quantizing and testing the module.
568- aten_ops: Aten dialect ops expected to be found in the graph after export.
569-
570- exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
571- if not using use_edge_to_transform_and_lower.
572- run_on_fvp: Set to true to test the pte fileon a fvp simulator.
573- use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
574- custom_path : Path to dump intermediate artifacts such as tosa and pte to.
575- """
562+ class EthosUPipelineINTBase (BasePipeline , Generic [T ]):
563+ """Base class that encapsulates shared Ethos-U INT pipeline setup."""
576564
577565 def __init__ (
578566 self ,
567+ compile_spec : EthosUCompileSpec ,
579568 module : torch .nn .Module ,
580569 test_data : T ,
581570 aten_ops : str | List [str ],
582- exir_ops : Optional [ str | List [str ]] = None ,
571+ exir_ops : str | Sequence [str ] | None ,
583572 run_on_fvp : bool = True ,
584573 symmetric_io_quantization : bool = False ,
585574 per_channel_quantization : bool = True ,
586575 a16w8_quantization : bool = False ,
587576 use_to_edge_transform_and_lower : bool = True ,
588- custom_path : str | None = None ,
589- tosa_debug_mode : Optional [ArmCompileSpec .DebugMode ] = None ,
590577 atol : float = 1e-03 ,
591578 rtol : float = 1e-03 ,
592579 qtol : int = 1 ,
593580 epsilon : float = 2 ** - 12 ,
594581 ):
595- compile_spec = common .get_u55_compile_spec (
596- custom_path = custom_path ,
597- tosa_debug_mode = tosa_debug_mode ,
582+ super ().__init__ (
583+ module ,
584+ test_data ,
585+ aten_ops ,
586+ compile_spec ,
587+ exir_ops ,
588+ use_to_edge_transform_and_lower ,
598589 )
590+
599591 quantizer = EthosUQuantizer (compile_spec )
600592 # choose int8 or int16 activation quantization
601593 if a16w8_quantization :
@@ -610,15 +602,6 @@ def __init__(
610602 quantizer .set_io (quantization_config )
611603 quant_stage = Quantize (quantizer , quantization_config )
612604
613- super ().__init__ (
614- module ,
615- test_data ,
616- aten_ops ,
617- compile_spec ,
618- exir_ops ,
619- use_to_edge_transform_and_lower ,
620- )
621-
622605 self .add_stage (self .tester .quantize , quant_stage , pos = 0 )
623606
624607 remove_quant_nodes_stage = (
@@ -659,19 +642,19 @@ def __init__(
659642 )
660643
661644
662- class EthosU85PipelineINT ( BasePipelineMaker , Generic [T ]):
645+ class EthosU55PipelineINT ( EthosUPipelineINTBase , Generic [T ]):
663646 """
664- Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true.
647+ Lowers a graph to u55 INT TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true.
665648
666649 Attributes:
667650 module: The module which the pipeline is applied to.
668651 test_data: Data used for quantizing and testing the module.
669652 aten_ops: Aten dialect ops expected to be found in the graph after export.
670653
671- exir_ops: Exir dialect ops expected to be found in the graph after to_edge if not using
672- use_edge_to_transform_and_lower.
673- run_on_fvp: Set to true to test the pte fileon a fvp simulator.
674- use_edge_to_transform_and_lower: Selects betweeen two possible ways of lowering the module.
654+ exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
655+ if not using use_edge_to_transform_and_lower.
656+ run_on_fvp: Set to true to test the pte file on a fvp simulator.
657+ use_edge_to_transform_and_lower: Selects between two possible ways of lowering the module.
675658 custom_path : Path to dump intermediate artifacts such as tosa and pte to.
676659 """
677660
@@ -680,7 +663,7 @@ def __init__(
680663 module : torch .nn .Module ,
681664 test_data : T ,
682665 aten_ops : str | List [str ],
683- exir_ops : str | List [str ] | None = None ,
666+ exir_ops : str | Sequence [str ] | None = None ,
684667 run_on_fvp : bool = True ,
685668 symmetric_io_quantization : bool = False ,
686669 per_channel_quantization : bool = True ,
@@ -693,74 +676,85 @@ def __init__(
693676 qtol : int = 1 ,
694677 epsilon : float = 2 ** - 12 ,
695678 ):
696- compile_spec = common .get_u85_compile_spec (
679+ compile_spec = common .get_u55_compile_spec (
697680 custom_path = custom_path ,
698681 tosa_debug_mode = tosa_debug_mode ,
699682 )
700- quantizer = EthosUQuantizer (compile_spec )
701- # choose int8 or int16 activation quantization
702- if a16w8_quantization :
703- quantization_config = get_symmetric_a16w8_quantization_config (
704- is_per_channel = per_channel_quantization , epsilon = epsilon
705- )
706- else :
707- quantization_config = get_symmetric_quantization_config (
708- is_per_channel = per_channel_quantization
709- )
710- if symmetric_io_quantization :
711- quantizer .set_io (quantization_config )
712- quant_stage = Quantize (quantizer , quantization_config )
713-
714683 super ().__init__ (
684+ compile_spec ,
715685 module ,
716686 test_data ,
717687 aten_ops ,
718- compile_spec ,
719688 exir_ops ,
720- use_to_edge_transform_and_lower ,
689+ run_on_fvp = run_on_fvp ,
690+ symmetric_io_quantization = symmetric_io_quantization ,
691+ per_channel_quantization = per_channel_quantization ,
692+ a16w8_quantization = a16w8_quantization ,
693+ use_to_edge_transform_and_lower = use_to_edge_transform_and_lower ,
694+ atol = atol ,
695+ rtol = rtol ,
696+ qtol = qtol ,
697+ epsilon = epsilon ,
721698 )
722699
723- self .add_stage (self .tester .quantize , quant_stage , pos = 0 )
724700
725- remove_quant_nodes_stage = (
726- "to_edge_transform_and_lower"
727- if use_to_edge_transform_and_lower
728- else "partition"
729- )
701+ class EthosU85PipelineINT (EthosUPipelineINTBase , Generic [T ]):
702+ """
703+ Lowers a graph to u85 INT TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true.
730704
731- if _has_quantizable_inputs (test_data ):
732- # only add stages if we have quantizable input
733- self .add_stage_after (
734- "quantize" ,
735- self .tester .check ,
736- [
737- "torch.ops.quantized_decomposed.dequantize_per_tensor.default" ,
738- "torch.ops.quantized_decomposed.quantize_per_tensor.default" ,
739- ],
740- suffix = "quant_nodes" ,
741- )
742- self .add_stage_after (
743- remove_quant_nodes_stage ,
744- self .tester .check_not ,
745- [
746- "torch.ops.quantized_decomposed.dequantize_per_tensor.default" ,
747- "torch.ops.quantized_decomposed.quantize_per_tensor.default" ,
748- ],
749- suffix = "quant_nodes" ,
750- )
705+ Attributes:
706+ module: The module which the pipeline is applied to.
707+ test_data: Data used for quantizing and testing the module.
708+ aten_ops: Aten dialect ops expected to be found in the graph after export.
751709
752- if run_on_fvp :
753- self .add_stage (self .tester .serialize )
754- self .add_stage (
755- self .tester .run_method_and_compare_outputs ,
756- atol = atol ,
757- rtol = rtol ,
758- qtol = qtol ,
759- inputs = self .test_data ,
760- )
710+ exir_ops: Exir dialect ops expected to be found in the graph after to_edge if not using
711+ use_edge_to_transform_and_lower.
712+ run_on_fvp: Set to true to test the pte file on a fvp simulator.
713+ use_edge_to_transform_and_lower: Selects between two possible ways of lowering the module.
714+ custom_path : Path to dump intermediate artifacts such as tosa and pte to.
715+ """
716+
717+ def __init__ (
718+ self ,
719+ module : torch .nn .Module ,
720+ test_data : T ,
721+ aten_ops : str | List [str ],
722+ exir_ops : str | Sequence [str ] | None = None ,
723+ run_on_fvp : bool = True ,
724+ symmetric_io_quantization : bool = False ,
725+ per_channel_quantization : bool = True ,
726+ a16w8_quantization : bool = False ,
727+ use_to_edge_transform_and_lower : bool = True ,
728+ custom_path : str | None = None ,
729+ tosa_debug_mode : Optional [ArmCompileSpec .DebugMode ] = None ,
730+ atol : float = 1e-03 ,
731+ rtol : float = 1e-03 ,
732+ qtol : int = 1 ,
733+ epsilon : float = 2 ** - 12 ,
734+ ):
735+ compile_spec = common .get_u85_compile_spec (
736+ custom_path = custom_path ,
737+ tosa_debug_mode = tosa_debug_mode ,
738+ )
739+ super ().__init__ (
740+ compile_spec ,
741+ module ,
742+ test_data ,
743+ aten_ops ,
744+ exir_ops ,
745+ run_on_fvp = run_on_fvp ,
746+ symmetric_io_quantization = symmetric_io_quantization ,
747+ per_channel_quantization = per_channel_quantization ,
748+ a16w8_quantization = a16w8_quantization ,
749+ use_to_edge_transform_and_lower = use_to_edge_transform_and_lower ,
750+ atol = atol ,
751+ rtol = rtol ,
752+ qtol = qtol ,
753+ epsilon = epsilon ,
754+ )
761755
762756
763- class PassPipeline (TOSAPipelineMaker , Generic [T ]):
757+ class PassPipeline (TOSAPipeline , Generic [T ]):
764758 """
765759 Runs single passes directly on an edge_program and checks operators before/after.
766760
@@ -858,7 +852,7 @@ def run(self):
858852 super ().run ()
859853
860854
861- class TransformAnnotationPassPipeline (TOSAPipelineMaker , Generic [T ]):
855+ class TransformAnnotationPassPipeline (TOSAPipeline , Generic [T ]):
862856 """
863857 Runs transform_for_annotation_pipeline passes directly on an exported program and checks output.
864858
@@ -914,7 +908,7 @@ def __init__(
914908 )
915909
916910
917- class QuantizationPipeline (TOSAPipelineMaker , Generic [T ]):
911+ class QuantizationPipeline (TOSAPipeline , Generic [T ]):
918912 """
919913 Runs quantization and checks that appropriate nodes are annotated with an expected
920914 quantization-spec.
@@ -971,7 +965,7 @@ def __init__(
971965 )
972966
973967
974- class OpNotSupportedPipeline (TOSAPipelineMaker , Generic [T ]):
968+ class OpNotSupportedPipeline (TOSAPipeline , Generic [T ]):
975969 """
976970 Runs the partitioner on a module and checks that ops are not delegated to test
977971 SupportedTOSAOperatorChecks.
@@ -1040,7 +1034,7 @@ def __init__(
10401034 self .pop_stage ("to_executorch" )
10411035
10421036
1043- class VgfPipeline (BasePipelineMaker , Generic [T ]):
1037+ class VgfPipeline (BasePipeline , Generic [T ]):
10441038 """
10451039 Lowers a graph based on TOSA spec (with or without quantization) and converts TOSA to VFG.
10461040
0 commit comments