Skip to content

Commit e0370ba

Browse files
Andrew Pullin authored and facebook-github-bot committed
Add DecomposeLstmPass for ARM backend (#17140)
Summary: Adds a decomposition pass that transforms `aten.lstm.input` into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep (where $\sigma$ is the sigmoid function and $\odot$ is the element-wise product):

$$i_t = \sigma\left(x_t W_{ii}^{\top} + b_{ii} + h_{t-1} W_{hi}^{\top} + b_{hi}\right)$$
$$f_t = \sigma\left(x_t W_{if}^{\top} + b_{if} + h_{t-1} W_{hf}^{\top} + b_{hf}\right)$$
$$g_t = \tanh\left(x_t W_{ig}^{\top} + b_{ig} + h_{t-1} W_{hg}^{\top} + b_{hg}\right)$$
$$o_t = \sigma\left(x_t W_{io}^{\top} + b_{io} + h_{t-1} W_{ho}^{\top} + b_{ho}\right)$$
$$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
$$h_t = o_t \odot \tanh(c_t)$$

Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277
1 parent e989803 commit e0370ba

4 files changed

Lines changed: 538 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@
71 71
from .decompose_linear_pass import DecomposeLinearPass # noqa
72 72
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
73 73
from .decompose_logit_pass import DecomposeLogitPass # noqa
74 +
from .decompose_lstm_pass import DecomposeLstmPass # noqa
74 75
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
75 76
from .decompose_matmul import DecomposeMatmulPass # noqa
76 77
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@
72 72
DecomposeLinearPass,
73 73
DecomposeLog1pPass,
74 74
DecomposeLogitPass,
75 +
DecomposeLstmPass,
75 76
DecomposeMaskedFillPass,
76 77
DecomposeMatmulPass,
77 78
DecomposeMaxPool2dPass,
@@ -363,6 +364,7 @@ def _tosa_pipeline(
363 364
DecomposeTOSAUnsupportedClampPass(),
364 365
DecomposeGroupNormPass(),
365 366
DecomposeGruPass(),
367 +
DecomposeLstmPass(),
366 368
DecomposeRnnPass(),
367 369
DecomposeLayerNormPass(),
368 370
DecomposeVarPass(),
@@ -583,6 +585,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
583 585
[
584 586
NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True),
585 587
DecomposeGruPass(tfa_pass=True),
588 +
DecomposeLstmPass(tfa_pass=True),
586 589
DecomposeRnnPass(tfa_pass=True),
587 590
DecomposeNotEqualPass(tfa_pass=True),
588 591
DecomposeCosineSimilarityPass(tfa_pass=True),

0 commit comments

Comments
 (0)