Commit 60d56eb
Add DecomposeGruPass for ARM backend (#17137)
Summary:
Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).
GRU cell equations per timestep:
r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
h_t = n_t + z_t * (h_{t-1} - n_t)
Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)
Differential Revision: D920583131 parent 493e84a commit 60d56eb
1 file changed
Lines changed: 3 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
66 | 66 | | |
67 | 67 | | |
68 | 68 | | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
69 | 72 | | |
70 | 73 | | |
71 | 74 | | |
| |||
0 commit comments