Skip to content

Commit 0c208b9

Browse files
Andrew Pullinfacebook-github-bot
authored andcommitted
Add DecomposeLstmPass for ARM backend (#17140)
Summary: Adds a decomposition pass that transforms aten.lstm.input into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat). LSTM cell equations per timestep: i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi) f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf) g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg) o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho) c_t = f_t * c_{t-1} + i_t * g_t h_t = o_t * tanh(c_t) Features: - Multi-layer LSTM support - Bidirectional LSTM support - With/without bias - batch_first support - Batched gate computation (2 mm ops per timestep instead of 8) --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Confucius Session](https://www.internalfb.com/confucius?host=62602.od.fbinfra.net&port=8086&tab=Chat&session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&entry_name=Code+Assist), [Trace](https://www.internalfb.com/confucius?session_id=e1d1ac52-0014-11f1-9d55-75b7d4e71d8a&tab=Trace) Differential Revision: D92059277
1 parent 1c44a77 commit 0c208b9

File tree

4 files changed

+531
-0
lines changed

4 files changed

+531
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from .decompose_linear_pass import DecomposeLinearPass # noqa
6565
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
6666
from .decompose_logit_pass import DecomposeLogitPass # noqa
67+
from .decompose_lstm_pass import DecomposeLstmPass # noqa
6768
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
6869
from .decompose_matmul import DecomposeMatmulPass # noqa
6970
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
DecomposeLinearPass,
6666
DecomposeLog1pPass,
6767
DecomposeLogitPass,
68+
DecomposeLstmPass,
6869
DecomposeMaskedFillPass,
6970
DecomposeMatmulPass,
7071
DecomposeMaxPool2dPass,
@@ -239,6 +240,7 @@ def _tosa_pipeline(
239240
DecomposeTOSAUnsupportedClampPass(),
240241
DecomposeGroupNormPass(),
241242
DecomposeGruPass(),
243+
DecomposeLstmPass(),
242244
DecomposeRnnPass(),
243245
DecomposeLayerNormPass(),
244246
DecomposeVarPass(),
@@ -429,6 +431,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
429431
DecomposeAddSubAlphaPass(tfa_pass=True),
430432
DecomposeGroupNormPass(tfa_pass=True),
431433
DecomposeGruPass(tfa_pass=True),
434+
DecomposeLstmPass(tfa_pass=True),
432435
DecomposeRnnPass(tfa_pass=True),
433436
DecomposeLayerNormPass(tfa_pass=True),
434437
DecomposeVarPass(tfa_pass=True),

0 commit comments

Comments
 (0)