Skip to content

Commit ecf49ca

Browse files
apullin and Andrew Pullin authored
Add DecomposeLstmPass for ARM backend (#17140)
Summary: Adds a decomposition pass that transforms `aten.lstm.input` into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep:

$$
\begin{aligned}
i_t &= \operatorname{sigmoid}\left(x_t W_{ii}^{\top} + b_{ii} + h_{t-1} W_{hi}^{\top} + b_{hi}\right) \\
f_t &= \operatorname{sigmoid}\left(x_t W_{if}^{\top} + b_{if} + h_{t-1} W_{hf}^{\top} + b_{hf}\right) \\
g_t &= \tanh\left(x_t W_{ig}^{\top} + b_{ig} + h_{t-1} W_{hg}^{\top} + b_{hg}\right) \\
o_t &= \operatorname{sigmoid}\left(x_t W_{io}^{\top} + b_{io} + h_{t-1} W_{ho}^{\top} + b_{ho}\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Here $\odot$ denotes element-wise multiplication (the `mul` op).

Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277

---------

Co-authored-by: Andrew Pullin <pullinandrew@meta.com>
1 parent 841181e commit ecf49ca

8 files changed

Lines changed: 1522 additions & 0 deletions

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
from .decompose_glu_pass import DecomposeGluPass # noqa
5454
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5555
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
56+
from .decompose_gru_pass import DecomposeGruPass # noqa
5657
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
5758
from .decompose_index_select_to_gather_pass import ( # noqa
5859
DecomposeIndexSelectToGatherPass,
@@ -70,13 +71,15 @@
7071
from .decompose_linear_pass import DecomposeLinearPass # noqa
7172
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
7273
from .decompose_logit_pass import DecomposeLogitPass # noqa
74+
from .decompose_lstm_pass import DecomposeLstmPass # noqa
7375
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
7476
from .decompose_matmul import DecomposeMatmulPass # noqa
7577
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
7678
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
7779
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
7880
from .decompose_quant_nodes import DecomposeQuantNodesPass # noqa
7981
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
82+
from .decompose_rnn_pass import DecomposeRnnPass # noqa
8083
from .decompose_round_pass import DecomposeRoundPass # noqa
8184
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
8285
from .decompose_select import DecomposeSelectPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,7 @@
6161
DecomposeGluPass,
6262
DecomposeGroupedConvPass,
6363
DecomposeGroupNormPass,
64+
DecomposeGruPass,
6465
DecomposeIndexCopyPass,
6566
DecomposeIndexSelectToGatherPass,
6667
DecomposeIndexTensorToGatherPass,
@@ -71,13 +72,15 @@
7172
DecomposeLinearPass,
7273
DecomposeLog1pPass,
7374
DecomposeLogitPass,
75+
DecomposeLstmPass,
7476
DecomposeMaskedFillPass,
7577
DecomposeMatmulPass,
7678
DecomposeMaxPool2dPass,
7779
DecomposeMeanDimPass,
7880
DecomposeNotEqualPass,
7981
DecomposeQuantNodesPass,
8082
DecomposeRemainderPass,
83+
DecomposeRnnPass,
8184
DecomposeRoundPass,
8285
DecomposeScaledDotProductAttentionPass,
8386
DecomposeSelectPass,
@@ -360,6 +363,9 @@ def _tosa_pipeline(
360363
ConvertToClampPass(),
361364
DecomposeTOSAUnsupportedClampPass(),
362365
DecomposeGroupNormPass(),
366+
DecomposeGruPass(),
367+
DecomposeLstmPass(),
368+
DecomposeRnnPass(),
363369
DecomposeLayerNormPass(),
364370
DecomposeVarPass(),
365371
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
@@ -578,6 +584,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
578584
self.add_passes(
579585
[
580586
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
587+
DecomposeGruPass(tfa_pass=True),
588+
DecomposeLstmPass(tfa_pass=True),
589+
DecomposeRnnPass(tfa_pass=True),
581590
DecomposeNotEqualPass(tfa_pass=True),
582591
DecomposeCosineSimilarityPass(tfa_pass=True),
583592
DecomposeGluPass(tfa_pass=True),

0 commit comments

Comments
 (0)