1111from executorch .exir .verification .verifier import EXIREdgeDialectVerifier
1212from torch import Tensor
1313from torch ._export .verifier import Verifier
14+ from torch ._ops import OpOverload
1415from torch .export import ExportedProgram
1516from torch .export .exported_program import ModuleCallEntry , ModuleCallSignature
1617from torch .export .graph_signature import (
@@ -32,12 +33,19 @@ class IrMode(Enum):
3233class ProgramBuilder (GraphBuilder ):
3334 """Utility class to build a program from a graph module."""
3435
35- def __init__ (self , mode : Optional [IrMode ] = None ) -> None :
36+ def __init__ (
37+ self ,
38+ mode : Optional [IrMode ] = None ,
39+ _core_aten_ops_exception_list : Optional [list [OpOverload ]] = None ,
40+ ) -> None :
3641 self .input_specs : list [InputSpec ] = []
3742 self .output_specs : list [OutputSpec ] = []
3843 self .constants : dict [str , Tensor ] = {}
3944 self .state_dict : dict [str , Tensor ] = {}
4045 self .mode : IrMode = mode or IrMode .EXIR
46+ self ._core_aten_ops_exception_list : list [OpOverload ] = (
47+ _core_aten_ops_exception_list or []
48+ )
4149 super ().__init__ ()
4250
4351 def insert_input_spec (
@@ -82,7 +90,11 @@ def get_verifiers(self) -> Optional[list[Verifier]]:
8290 return None
8391 return [
8492 EXIREdgeDialectVerifier (
85- edge_compile_config = EdgeCompileConfig (_check_ir_validity = False ),
93+ edge_compile_config = EdgeCompileConfig (
94+ _check_ir_validity = False ,
95+ _core_aten_ops_exception_list = self ._core_aten_ops_exception_list ,
96+ ),
97+ core_aten_ops_exception_list = self ._core_aten_ops_exception_list ,
8698 class_only = True ,
8799 )
88100 ]
@@ -113,4 +125,7 @@ def get_program(self) -> ExportedProgram:
113125 )
114126
115127 def get_edge_program (self ) -> EdgeProgramManager :
116- return EdgeProgramManager (self .get_program ())
128+ return EdgeProgramManager (
129+ self .get_program (),
130+ core_aten_ops_exception_list = self ._core_aten_ops_exception_list ,
131+ )
0 commit comments