Commit ecf49ca
Summary:
Adds a decomposition pass that transforms aten.lstm.input into
elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).
LSTM cell equations per timestep:
i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
c_t = f_t * c_{t-1} + i_t * g_t
h_t = o_t * tanh(c_t)
Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8 )
Differential Revision: D92059277
---------
Co-authored-by: Andrew Pullin <pullinandrew@meta.com>
1 parent 841181e commit ecf49ca
8 files changed
Lines changed: 1522 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
| 56 | + | |
56 | 57 | | |
57 | 58 | | |
58 | 59 | | |
| |||
70 | 71 | | |
71 | 72 | | |
72 | 73 | | |
| 74 | + | |
73 | 75 | | |
74 | 76 | | |
75 | 77 | | |
76 | 78 | | |
77 | 79 | | |
78 | 80 | | |
79 | 81 | | |
| 82 | + | |
80 | 83 | | |
81 | 84 | | |
82 | 85 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| 64 | + | |
64 | 65 | | |
65 | 66 | | |
66 | 67 | | |
| |||
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| 75 | + | |
74 | 76 | | |
75 | 77 | | |
76 | 78 | | |
77 | 79 | | |
78 | 80 | | |
79 | 81 | | |
80 | 82 | | |
| 83 | + | |
81 | 84 | | |
82 | 85 | | |
83 | 86 | | |
| |||
360 | 363 | | |
361 | 364 | | |
362 | 365 | | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
363 | 369 | | |
364 | 370 | | |
365 | 371 | | |
| |||
578 | 584 | | |
579 | 585 | | |
580 | 586 | | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
581 | 590 | | |
582 | 591 | | |
583 | 592 | | |
| |||
0 commit comments