Skip to content

Commit 936fc09

Browse files
Arm backend: Use TOSA-context when adding passes
Make sure to be inside TosaLoweringContext when adding/initializing passes. This is needed as passes may try to set self.shape_env using get_context_shape_env(). This patch also uses _get_shape_env_from_gm instead of graph_module.shape_env as we cannot rely on the shape_env being set in the graph_module even though sym-ints are present in the graph. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: If515e12af3382909320585d1a1064281203cbda9
1 parent 87e65ac commit 936fc09

1 file changed

Lines changed: 83 additions & 78 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 83 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@
159159
from executorch.exir import ExportedProgram
160160
from executorch.exir.pass_base import ExportPass
161161
from executorch.exir.pass_manager import PassManager
162+
from torch._export.utils import _get_shape_env_from_gm
162163
from torch.fx import GraphModule
163164
from torch.fx.passes.infra.pass_base import PassResult
164165
from torch.nn.modules import Module
@@ -335,10 +336,12 @@ def add_passes(self, passes: Sequence[ExportPass | None]):
335336
if p is not None:
336337
self.add_pass(p)
337338

339+
def _tosa_context(self, graph_module: GraphModule) -> TosaLoweringContext:
340+
shape_env = _get_shape_env_from_gm(graph_module)
341+
return TosaLoweringContext(self.tosa_spec, shape_env)
342+
338343
def _transform(self, graph_module: GraphModule):
339-
shape_env = graph_module.shape_env
340-
with TosaLoweringContext(self.tosa_spec, shape_env):
341-
return self(graph_module).graph_module
344+
return self(graph_module).graph_module
342345

343346
def add_pass(self, pipeline_pass):
344347
if type(pipeline_pass) in self._skip_pass_types:
@@ -532,87 +535,89 @@ def transform_to_backend_pipeline(
532535
f"No pass pipeline found for TOSA specification: {self.tosa_spec}"
533536
)
534537

535-
return self._tosa_pipeline(exported_program, graph_module)
538+
with self._tosa_context(graph_module):
539+
return self._tosa_pipeline(exported_program, graph_module)
536540

537541
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
538-
# Preprocessing passes
539-
self.add_pass(RemoveGraphAssertsPass(tfa_pass=True))
540-
self.add_pass(ConstantFoldingPass())
541-
542-
# Transformation passes (pre scalar -> tensor)
543-
self.add_passes(
544-
[
545-
DecomposeIndexCopyPass(tfa_pass=True),
546-
DecomposeSelectScatterPass(tfa_pass=True),
547-
DecomposeSliceScatterPass(tfa_pass=True),
548-
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
549-
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
550-
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
551-
DecomposeEmbeddingPass(tfa_pass=True),
552-
DecomposeScaledDotProductAttentionPass(tfa_pass=True),
553-
DecomposeRoundPass(tfa_pass=True),
554-
DecomposeLogitPass(tfa_pass=True),
555-
PromoteBoolOperandsPass(tfa_pass=True),
556-
DecomposeSignPass(tfa_pass=True),
557-
DecomposeTrilPass(tfa_pass=True),
558-
DecomposeAddmmPass(tfa_pass=True),
559-
DecomposeRemainderPass(tfa_pass=True),
560-
DecomposeFloorDividePass(tfa_pass=True),
561-
DecomposeDivTensorModePass(tfa_pass=True),
562-
DecomposeWhereScalarOtherPass(tfa_pass=True),
563-
DecomposeEinsumPass(tfa_pass=True),
564-
RewriteInplaceArithmeticPass(tfa_pass=True),
565-
DecomposeAddSubAlphaPass(tfa_pass=True),
566-
DecomposeLeakyReLUPass(tfa_pass=True),
567-
DecomposeGroupNormPass(tfa_pass=True),
568-
DecomposeLayerNormPass(tfa_pass=True),
569-
DecomposeVarPass(tfa_pass=True),
570-
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
571-
]
572-
)
542+
with self._tosa_context(graph_module):
543+
# Preprocessing passes
544+
self.add_pass(RemoveGraphAssertsPass(tfa_pass=True))
545+
self.add_pass(ConstantFoldingPass())
546+
547+
# Transformation passes (pre scalar -> tensor)
548+
self.add_passes(
549+
[
550+
DecomposeIndexCopyPass(tfa_pass=True),
551+
DecomposeSelectScatterPass(tfa_pass=True),
552+
DecomposeSliceScatterPass(tfa_pass=True),
553+
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
554+
ConvertInt64OutputOpsToInt32Pass(tfa_pass=True),
555+
InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True),
556+
DecomposeEmbeddingPass(tfa_pass=True),
557+
DecomposeScaledDotProductAttentionPass(tfa_pass=True),
558+
DecomposeRoundPass(tfa_pass=True),
559+
DecomposeLogitPass(tfa_pass=True),
560+
PromoteBoolOperandsPass(tfa_pass=True),
561+
DecomposeSignPass(tfa_pass=True),
562+
DecomposeTrilPass(tfa_pass=True),
563+
DecomposeAddmmPass(tfa_pass=True),
564+
DecomposeRemainderPass(tfa_pass=True),
565+
DecomposeFloorDividePass(tfa_pass=True),
566+
DecomposeDivTensorModePass(tfa_pass=True),
567+
DecomposeWhereScalarOtherPass(tfa_pass=True),
568+
DecomposeEinsumPass(tfa_pass=True),
569+
RewriteInplaceArithmeticPass(tfa_pass=True),
570+
DecomposeAddSubAlphaPass(tfa_pass=True),
571+
DecomposeLeakyReLUPass(tfa_pass=True),
572+
DecomposeGroupNormPass(tfa_pass=True),
573+
DecomposeLayerNormPass(tfa_pass=True),
574+
DecomposeVarPass(tfa_pass=True),
575+
DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True),
576+
]
577+
)
573578

574-
# Scalars -> tensors
575-
self.add_passes(
576-
[
577-
ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
578-
ScalarsToAttributePass(tfa_pass=True),
579-
ControlFlowConstInlinePass(tfa_pass=True),
580-
]
581-
)
579+
# Scalars -> tensors
580+
self.add_passes(
581+
[
582+
ReplaceScalarWithTensorByProfilePass(tfa_pass=True),
583+
ScalarsToAttributePass(tfa_pass=True),
584+
ControlFlowConstInlinePass(tfa_pass=True),
585+
]
586+
)
582587

583-
# Transformation passes (post scalar removal)
584-
self.add_passes(
585-
[
586-
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
587-
DecomposeGruPass(tfa_pass=True),
588-
DecomposeLstmPass(tfa_pass=True),
589-
DecomposeRnnPass(tfa_pass=True),
590-
DecomposeNotEqualPass(tfa_pass=True),
591-
DecomposeCosineSimilarityPass(tfa_pass=True),
592-
DecomposeGluPass(tfa_pass=True),
593-
DecomposeDivPass(tfa_pass=True),
594-
DecomposeLinalgVectorNormPass(tfa_pass=True),
595-
DecomposeSqrtPass(tfa_pass=True),
596-
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
597-
DecomposeAvgPool2dPass(tfa_pass=True),
598-
DecomposeSoftmaxPass(
599-
tfa_pass=True,
600-
),
601-
ConvertMinMaxPass(tfa_pass=True),
602-
AccumulateIndexPutPass(tfa_pass=True),
603-
DecomposeMatmulPass(tfa_pass=True),
604-
]
605-
)
588+
# Transformation passes (post scalar removal)
589+
self.add_passes(
590+
[
591+
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
592+
DecomposeGruPass(tfa_pass=True),
593+
DecomposeLstmPass(tfa_pass=True),
594+
DecomposeRnnPass(tfa_pass=True),
595+
DecomposeNotEqualPass(tfa_pass=True),
596+
DecomposeCosineSimilarityPass(tfa_pass=True),
597+
DecomposeGluPass(tfa_pass=True),
598+
DecomposeDivPass(tfa_pass=True),
599+
DecomposeLinalgVectorNormPass(tfa_pass=True),
600+
DecomposeSqrtPass(tfa_pass=True),
601+
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
602+
DecomposeAvgPool2dPass(tfa_pass=True),
603+
DecomposeSoftmaxPass(
604+
tfa_pass=True,
605+
),
606+
ConvertMinMaxPass(tfa_pass=True),
607+
AccumulateIndexPutPass(tfa_pass=True),
608+
DecomposeMatmulPass(tfa_pass=True),
609+
]
610+
)
606611

607-
# Postprocessing passes
608-
self.add_passes(
609-
[
610-
ReplaceInfAndLimitValuesPass(tfa_pass=True),
611-
DecomposeMaskedFillPass(tfa_pass=True),
612-
]
613-
)
612+
# Postprocessing passes
613+
self.add_passes(
614+
[
615+
ReplaceInfAndLimitValuesPass(tfa_pass=True),
616+
DecomposeMaskedFillPass(tfa_pass=True),
617+
]
618+
)
614619

615-
return self._transform(graph_module)
620+
return self._transform(graph_module)
616621

617622
def __call__(self, module: Module) -> PassResult:
618623
try:

0 commit comments

Comments
 (0)