Skip to content

Commit 5c2da73

Browse files
Andrew Pullinfacebook-github-bot
authored andcommitted
Add DecomposeLstmPass for ARM backend (#17140)
Summary: Adds a decomposition pass that transforms aten.lstm.input into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat). LSTM cell equations per timestep: i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi) f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf) g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg) o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho) c_t = f_t * c_{t-1} + i_t * g_t h_t = o_t * tanh(c_t) Features: - Multi-layer LSTM support - Bidirectional LSTM support - With/without bias - batch_first support - Batched gate computation (2 mm ops per timestep instead of 8) Differential Revision: D92059277
1 parent 01e03c9 commit 5c2da73

5 files changed

Lines changed: 543 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from .decompose_linear_pass import DecomposeLinearPass # noqa
7272
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
7373
from .decompose_logit_pass import DecomposeLogitPass # noqa
74+
from .decompose_lstm_pass import DecomposeLstmPass # noqa
7475
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
7576
from .decompose_matmul import DecomposeMatmulPass # noqa
7677
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
DecomposeLinearPass,
7373
DecomposeLog1pPass,
7474
DecomposeLogitPass,
75+
DecomposeLstmPass,
7576
DecomposeMaskedFillPass,
7677
DecomposeMatmulPass,
7778
DecomposeMaxPool2dPass,
@@ -368,6 +369,7 @@ def _tosa_pipeline(
368369
DecomposeTOSAUnsupportedClampPass(),
369370
DecomposeGroupNormPass(),
370371
DecomposeGruPass(),
372+
DecomposeLstmPass(),
371373
DecomposeRnnPass(),
372374
DecomposeLayerNormPass(),
373375
DecomposeVarPass(),
@@ -590,6 +592,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
590592
[
591593
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
592594
DecomposeGruPass(tfa_pass=True),
595+
DecomposeLstmPass(tfa_pass=True),
593596
DecomposeRnnPass(tfa_pass=True),
594597
DecomposeNotEqualPass(tfa_pass=True),
595598
DecomposeCosineSimilarityPass(tfa_pass=True),

0 commit comments

Comments
 (0)