Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_gru_pass import DecomposeGruPass # noqa
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
from .decompose_index_select_to_gather_pass import ( # noqa
DecomposeIndexSelectToGatherPass,
Expand All @@ -70,13 +71,15 @@
from .decompose_linear_pass import DecomposeLinearPass # noqa
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
from .decompose_logit_pass import DecomposeLogitPass # noqa
from .decompose_lstm_pass import DecomposeLstmPass # noqa
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
from .decompose_matmul import DecomposeMatmulPass # noqa
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
from .decompose_rnn_pass import DecomposeRnnPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
Expand Down
9 changes: 9 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
DecomposeGluPass,
DecomposeGroupedConvPass,
DecomposeGroupNormPass,
DecomposeGruPass,
DecomposeIndexCopyPass,
DecomposeIndexSelectToGatherPass,
DecomposeIndexTensorToGatherPass,
Expand All @@ -71,13 +72,15 @@
DecomposeLinearPass,
DecomposeLog1pPass,
DecomposeLogitPass,
DecomposeLstmPass,
DecomposeMaskedFillPass,
DecomposeMatmulPass,
DecomposeMaxPool2dPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeQuantNodesPass,
DecomposeRemainderPass,
DecomposeRnnPass,
DecomposeRoundPass,
DecomposeScaledDotProductAttentionPass,
DecomposeSelectPass,
Expand Down Expand Up @@ -365,6 +368,9 @@ def _tosa_pipeline(
ConvertToClampPass(),
DecomposeTOSAUnsupportedClampPass(),
DecomposeGroupNormPass(),
DecomposeGruPass(),
DecomposeLstmPass(),
DecomposeRnnPass(),
DecomposeLayerNormPass(),
DecomposeVarPass(),
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
Expand Down Expand Up @@ -585,6 +591,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_passes(
[
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
DecomposeGruPass(tfa_pass=True),
DecomposeLstmPass(tfa_pass=True),
DecomposeRnnPass(tfa_pass=True),
DecomposeNotEqualPass(tfa_pass=True),
DecomposeCosineSimilarityPass(tfa_pass=True),
DecomposeGluPass(tfa_pass=True),
Expand Down
Loading
Loading