Commit 60d56eb

Andrew Pullin authored and facebook-github-bot committed
Add DecomposeGruPass for ARM backend (#17137)
Summary:
Adds a decomposition pass that transforms aten.gru.input into elementary ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
  r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
  z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
  n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
  h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)

Differential Revision: D92058313
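The batched gate computation mentioned in the summary can be illustrated with a small dependency-free sketch. It is not the pass itself, which builds torch.fx graph nodes; it only checks the arithmetic on plain Python lists. The function names (linear, gru_step_batched, gru_step_per_gate) are hypothetical, and the sketch assumes the gate weights are stacked in r, z, n order along the first dimension, as in torch.nn.GRU's weight_ih_l0 and weight_hh_l0.

```python
import math
import random


def linear(x, w, b):
    """Compute x @ w.T + b for x: [batch, in], w: [out, in], b: [out]."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + bias
         for w_row, bias in zip(w, b)]
        for row in x
    ]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def combine(ir, hr, iz, hz, i_n, hn, h_prev):
    """Apply the GRU cell equations element-wise for one batch row."""
    out = []
    for j in range(len(h_prev)):
        r = sigmoid(ir[j] + hr[j])
        z = sigmoid(iz[j] + hz[j])
        n = math.tanh(i_n[j] + r * hn[j])
        out.append(n + z * (h_prev[j] - n))
    return out


def gru_step_batched(x, h, w_ih, w_hh, b_ih, b_hh):
    """One timestep with two matrix products, then slicing per gate.

    Assumes w_ih: [3*hidden, input] and w_hh: [3*hidden, hidden],
    with the gates stacked in r, z, n order.
    """
    hid = len(h[0])
    gi = linear(x, w_ih, b_ih)  # all three input-side gates at once
    gh = linear(h, w_hh, b_hh)  # all three hidden-side gates at once
    return [
        combine(a[:hid], c[:hid], a[hid:2 * hid], c[hid:2 * hid],
                a[2 * hid:], c[2 * hid:], h_row)
        for a, c, h_row in zip(gi, gh, h)
    ]


def gru_step_per_gate(x, h, w_ih, w_hh, b_ih, b_hh):
    """Reference version: six matrix products, one per gate and side."""
    hid = len(h[0])

    def gate(inp, w, b, k):
        rows = slice(k * hid, (k + 1) * hid)
        return linear(inp, w[rows], b[rows])

    ir, iz, i_n = (gate(x, w_ih, b_ih, k) for k in range(3))
    hr, hz, hn = (gate(h, w_hh, b_hh, k) for k in range(3))
    return [
        combine(ir[i], hr[i], iz[i], hz[i], i_n[i], hn[i], h[i])
        for i in range(len(h))
    ]


# Small random example (sizes chosen arbitrarily for illustration).
rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


batch, input_size, hidden_size = 2, 3, 4
x = rand_matrix(batch, input_size)
h0 = rand_matrix(batch, hidden_size)
w_ih = rand_matrix(3 * hidden_size, input_size)
w_hh = rand_matrix(3 * hidden_size, hidden_size)
b_ih = rand_matrix(1, 3 * hidden_size)[0]
b_hh = rand_matrix(1, 3 * hidden_size)[0]

h_batched = gru_step_batched(x, h0, w_ih, w_hh, b_ih, b_hh)
h_ref = gru_step_per_gate(x, h0, w_ih, w_hh, b_ih, b_hh)
max_diff = max(
    abs(a - b) for ra, rb in zip(h_batched, h_ref) for a, b in zip(ra, rb)
)
print("max abs difference:", max_diff)
```

Both functions apply the same cell equations. The batched variant computes each side's three gate pre-activations in a single product and slices the result, which is where the drop from six products to two per timestep comes from.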
1 parent 493e84a commit 60d56eb

1 file changed

Lines changed: 3 additions & 0 deletions

backends/arm/_passes/decompose_gru_pass.py

@@ -66,6 +66,9 @@ def _build_direction(
 ) -> Tuple[List[torch.fx.Node], torch.fx.Node]:
 """Build GRU cell computation for one direction.

+Uses aten.linear (matching PyTorch's standard decomposition) instead
+of raw mm to avoid ConvertMmToBmmPass issues with quantized tensors.
+
 Returns (timestep_outputs, h_final) where timestep_outputs are
 unsqueezed hidden states in forward time order.