Commit 5c2da73
Add DecomposeLstmPass for ARM backend (#17140)
Summary:
Adds a decomposition pass that transforms aten.lstm.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).
LSTM cell equations per timestep:
i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
c_t = f_t * c_{t-1} + i_t * g_t
h_t = o_t * tanh(c_t)
Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8)
Differential Revision: D920592771 parent 01e03c9 commit 5c2da73
5 files changed
Lines changed: 543 additions & 0 deletions
File tree
- backends/arm
- _passes
- test/passes
- tosa
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
71 | 71 | | |
72 | 72 | | |
73 | 73 | | |
| 74 | + | |
74 | 75 | | |
75 | 76 | | |
76 | 77 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
| 75 | + | |
75 | 76 | | |
76 | 77 | | |
77 | 78 | | |
| |||
368 | 369 | | |
369 | 370 | | |
370 | 371 | | |
| 372 | + | |
371 | 373 | | |
372 | 374 | | |
373 | 375 | | |
| |||
590 | 592 | | |
591 | 593 | | |
592 | 594 | | |
| 595 | + | |
593 | 596 | | |
594 | 597 | | |
595 | 598 | | |
| |||
0 commit comments