Add DecomposeLstmPass for ARM backend (#17140)#17140

Open
apullin wants to merge 3 commits into pytorch:main from apullin:export-D92059277

Conversation

@apullin
Contributor

@apullin apullin commented Feb 3, 2026

Summary:

Adds a decomposition pass that transforms aten.lstm.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep:
i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
c_t = f_t * c_{t-1} + i_t * g_t
h_t = o_t * tanh(c_t)

Features:

  • Multi-layer LSTM support
  • Bidirectional LSTM support
  • With/without bias
  • batch_first support
  • Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277
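
The per-timestep equations above can be sketched in plain Python. This is a hand-rolled illustration with list-based vectors, not the actual pass (which emits the corresponding aten/TOSA ops on tensors); the helper names `lstm_cell` and `mv` are made up here, but the stacked-weight layout mirrors PyTorch's convention of 4*H gate rows in `weight_ih`/`weight_hh`, which is what enables the "2 mm ops instead of 8" batching:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def mv(w, v):
    """Matrix-vector product over plain lists (stand-in for aten.mm)."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

def lstm_cell(x, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    """One LSTM timestep decomposed into matmul/add/sigmoid/tanh/mul/slice.

    w_ih and w_hh each have 4*H rows (i, f, g, o gates stacked), so the
    whole gate pre-activation takes two matrix products per timestep
    instead of eight.
    """
    h_size = len(h_prev)
    gates = [xi + hi + bi + bh
             for xi, hi, bi, bh in zip(mv(w_ih, x), mv(w_hh, h_prev), b_ih, b_hh)]
    # Slice the stacked pre-activations into the four gates.
    i_t = [sigmoid(g) for g in gates[0:h_size]]
    f_t = [sigmoid(g) for g in gates[h_size:2 * h_size]]
    g_t = [math.tanh(g) for g in gates[2 * h_size:3 * h_size]]
    o_t = [sigmoid(g) for g in gates[3 * h_size:4 * h_size]]
    c_t = [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, g_t)]
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return h_t, c_t
```

With all-zero weights and biases the gate pre-activations are zero, so i, f, and o collapse to 0.5 and g to 0, which makes the cell-state update easy to check by hand.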

@apullin apullin requested a review from digantdesai as a code owner February 3, 2026 07:35
@pytorch-bot

pytorch-bot bot commented Feb 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17140

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Awaiting Approval, 2 Cancelled Jobs

As of commit 5c2da73 with merge base 3466332:

AWAITING APPROVAL - The following workflows need approval before CI can run:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 3, 2026
@meta-codesync
Contributor

meta-codesync bot commented Feb 3, 2026

@apullin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92059277.

@github-actions

github-actions bot commented Feb 3, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059277 branch 3 times, most recently from bf1d013 to f57bc19 Compare February 3, 2026 23:23
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@apullin apullin force-pushed the export-D92059277 branch 3 times, most recently from bc86171 to 62726e7 Compare February 3, 2026 23:55
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
apullin pushed a commit to apullin/executorch that referenced this pull request Feb 3, 2026
@zingo zingo added partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm ciflow/trunk labels Feb 6, 2026
@pytorch-bot

pytorch-bot bot commented Feb 6, 2026

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/trunk label Feb 6, 2026
@zingo zingo changed the title Add DecomposeLstmPass for ARM backend Arm backend: Add DecomposeLstmPass Feb 6, 2026
pytorch-bot bot pushed a commit that referenced this pull request Feb 6, 2026
@gggekov
Collaborator

gggekov commented Feb 6, 2026

What is the reason to decompose the LSTM in the Arm backend rather than letting torch.export.export decompose it?

@gggekov
Collaborator

gggekov commented Feb 6, 2026

Never mind. I see that torch.nn.LSTM is not decomposed by torch.export.export as I initially thought.

Contributor Author

@apullin apullin left a comment


@mansnils I'm trying to clear out the previously requested changes related to MLETORCH-1266, but those have long since been settled by others' PRs. Do I have to do anything from my end?

Collaborator

@gggekov gggekov left a comment


Thanks for updating the PR and fixing the CI! It's a great PR, and it would be really nice to add support for torch.nn.LSTM/RNN/GRU. I have added comments in the unit tests, and I have two overall comments.

  1. Could you please use TosaPipelineFP and TosaPipelineINT rather than the PassPipeline to test the decompositions for the three operators?

  2. We have a very strict naming convention for unit tests. We are working on loosening it, but right now, if a unit test does not respect the convention, it is not picked up by the CI. I've added comments on how to name the various tests if you keep them in backends/arm/test/passes.
    IMO, it would make more sense to move the tests under backends/arm/test/ops; that's where we test most operators that we decompose in transform_for_annotation (e.g. test_adaptive_avg_pool2d.py, test_gelu.py, test_layer_norm.py, etc.). In the case of LSTM, we also already have a unit test in executorch/backends/arm/test/models/test_lstm_arm.py for the rnn.LSTM; you can create another model with torch.nn.LSTM and amend that file if you go down the route of moving the tests.

@apullin
Contributor Author

apullin commented Mar 26, 2026

@gggekov The renaming is done.

And: I redid the tests to use TosaPipelineFP & TosaPipelineINT for all three layer types (RNN, GRU, LSTM).

BUT: The LSTM TosaPipeline tests are currently skipped due to a TOSA backend shape mismatch on h_n/c_n (our decomposition pass outputs correct shapes, verified by the unit tests; the extra dimension is introduced during TOSA lowering). I'm pretty sure it's a TOSA backend issue. I can attempt to fix this directly in another commit, but that is some scope creep. Unless I am wrong about the failure's origin.

I also tried adding EthosU55PipelineINT / EthosU85PipelineINT tests, and they ran fine locally, but GH CI doesn't seem to have the ethos-u-vela package installed, so they fail there despite passing locally.

@gggekov
Collaborator

gggekov commented Mar 27, 2026

Hi @apullin ,
I am also getting a failure for the test_decompose_lstm_tosa_FP_e2e test case

FAILED backends/arm/test/passes/test_decompose_lstm_pass.py::test_decompose_lstm_tosa_FP_e2e - ValueError: Output needs to be of same shape: torch.Size([1, 1, 2, 20]) != torch.Size([1, 2, 20])

What is the reason for the shape mismatch for TOSA FP? Normally, when lowering an operator, the hard part is getting the numerics right for TOSA INT. In your case, the numerics for TOSA INT seem fine, but you get a shape mismatch for TOSA FP. I wonder if something in another pass, run before the DecomposeLstmPass, causes the shape mismatch. That's why it's important to test via TosaPipelineINT & TosaPipelineFP: you run the full pipeline rather than a single pass in isolation.

Could you also add TosaPipelineINT & TosaPipelineFP tests for the other cases ( LSTM(bidirectional=True) , LSTM(bias=False), LSTM(num_layers=2)) ?

@gggekov
Collaborator

gggekov commented Mar 27, 2026

Could you also add TosaPipelineINT & TosaPipelineFP tests for the other cases ( LSTM(bidirectional=True) , LSTM(bias=False), LSTM(num_layers=2)) ?

This also applies to the RNN & GRU test cases. It would be useful to test the various attributes of RNN & GRU with TosaPipelineFP and TosaPipelineINT in order to make sure we run the full pipeline rather than just one pass. That's how we test all the operators we support via transform_for_annotation.

One idea about the shape mismatch for LSTM FP: you may need to run the pass at a slightly different place in arm_pass_manager.py. I did a few tests locally running the DecomposeLstmPass earlier in _tosa_pipeline, although I kept getting a shape mismatch for TOSA FP :|

@apullin
Contributor Author

apullin commented Mar 30, 2026

@gggekov ok, I THINK I have this down now. Updated all three PRs with the following changes:

All three passes (GRU, RNN, LSTM):

  • Added graph_module = super().call(graph_module).graph_module after recompile() to re-propagate FakeTensor shapes after decomposition. This was the root cause of the TOSA shape mismatch on LSTM h_n/c_n outputs.

Tests:

  • All PassPipeline tests replaced with TosaPipelineFP + TosaPipelineINT across all three layer types.
  • LSTM tests are no longer skipped: the shape fix plus a workaround (run_transform_for_annotation_pipeline before to_edge) resolves the extra-dimension issue. All 8 LSTM tests (4 FP + 4 INT) now run end-to-end through TosaPipeline.
  • GRU: 8 tests (4 configs × FP + INT)
  • RNN: 10 tests (5 configs × FP + INT)
  • LSTM: 8 tests (4 configs × FP + INT)

All tests pass in a Linux environment, but we will have to see if CI agrees.

Collaborator

@gggekov gggekov left a comment


Hello @apullin ,
Thank you for updating the unit tests to use TosaPipelineFP/TosaPipelineINT! I think the PR is going in the right direction.

My main comment is that currently your three unit tests test_decompose_<rnn/lstm/gru>_pass.py only work for TOSA INT (not sure if that is intentional?). For the TosaPipelineFP unit tests, your decomposition is never called. You are currently working around that in a slightly odd way by manually calling transform_for_annotation after the export from the ArmTester, but that is masking the underlying problem.

I think the proper fix is to make the pass decomposition cover both the EdgeIR ops (currently missing in your PR) and the ATen ops (what you currently do). See for example here (https://github.com/pytorch/executorch/blob/main/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py#L23). That way, for the TOSA FP case you would use the EdgeIR decomposition, and for the TOSA INT case you'd use the ATen decomposition.

Why do you need the EdgeIR decomposition for the TOSA FP case? Have a look at https://github.com/pytorch/executorch/blob/main/exir/program/_program.py#L1384 and insert a breakpoint/print statement before and after the call to edge_manager = _gen_edge_manager_for_partitioners(....).
Do print(aten_programs['forward']) on line 1383 and print(edge_manager.exported_program("forward")) on line 1387. On line 1383, you still have the torch.ops.aten.lstm.input that you decompose. On line 1387, however, you no longer have torch.ops.aten.lstm; you have EdgeIR ops instead. Note that this code runs before our backend has been called. Under the hood, the edge_manager = _gen_edge_manager_for_partitioners(....) API calls the from torch._decomp import get_decompositions API to do the decomposition from ATen to EdgeIR (I mentioned that API at the very start of the code review a few weeks ago). I think that for TosaPipelineFP, because torch.ops.aten.lstm.input has been replaced with EdgeIR ops, your pass is never called, and the unit test passes by relying on the decompositions that take place before our backend is reached. We'd like consistency between TOSA FP & TOSA INT: the same decomposition for both profiles. For other passes decomposed in transform_for_annotation, the TOSA FP case is handled via the EdgeIR operators.

Two more comments:

  • I notice that in the first commit (GRU) you add test_decompose_recurrent_tosa_pipelines.py, and then in the last commit (LSTM) you remove that file. Is it not possible to leave it out of the first commit altogether?
  • Do you mind running pip install -r requirements-lintrunner.txt && lintrunner init and then running the ./backends/arm/scripts/pre-commit script to run the lintrunner? I had to manually fix a few minor errors to get your PR into the internal review.

Well done, thanks for pushing with that!

@gggekov
Copy link
Copy Markdown
Collaborator

gggekov commented Apr 1, 2026

Hi @apullin ,
I had a second thought: I think for the TOSA FP case it will be hard to do your own decomposition, given that the op is already decomposed by ExecuTorch via the from torch._decomp import get_decompositions API before our backend is reached, as per my comment above. I suggest just removing the _add_lstm_workaround function and instead marking the TOSA FP tests as skipped with

@pytest.mark.skip(reason="Shape mismatch for TOSA FP")

and adding an appropriate comment that your pass only handles the TOSA INT case. For TOSA INT, you shouldn't need the _add_lstm_workaround function.

It would be interesting to understand where the shape mismatch stems from in the TOSA FP case. Maybe the from torch._decomp import get_decompositions API decomposes the op in a way that results in a slightly different shape compared to running torch.nn.LSTM in eager mode, or maybe we are accidentally introducing an extra dimension somewhere in our FP passes. I am not sure. I suggest proceeding with the PR handling just the TOSA INT case, which should make it possible to run the LSTM on the Ethos-U85.

@apullin
Contributor Author

apullin commented Apr 2, 2026

@gggekov done, updated across all three commits:

  • Removed _add_lstm_workaround and test_decompose_recurrent_tosa_pipelines.py
  • All tests use TosaPipelineFP / TosaPipelineINT (basic, bidirectional, no_bias, multilayer)
  • LSTM FP tests are skipped but present; INT is still faithfully tested
  • No feature bleed between commits

Andrew Pullin added 3 commits April 6, 2026 09:51
Summary:

Adds a decomposition pass that transforms aten.gru.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

GRU cell equations per timestep:
    r_t = sigmoid(x_t @ W_ir.T + b_ir + h_{t-1} @ W_hr.T + b_hr)
    z_t = sigmoid(x_t @ W_iz.T + b_iz + h_{t-1} @ W_hz.T + b_hz)
    n_t = tanh(x_t @ W_in.T + b_in + r_t * (h_{t-1} @ W_hn.T + b_hn))
    h_t = n_t + z_t * (h_{t-1} - n_t)

Features:
- Multi-layer GRU support
- Bidirectional GRU support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 6)

Differential Revision: D92058313
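
The GRU equations above can be sketched the same way. This is an illustrative hand-rolled version with list-based vectors, not the pass itself; `gru_cell` and `mv` are names invented here, and the 3*H stacked rows in w_ih/w_hh follow PyTorch's weight layout, giving the "2 mm ops instead of 6" batching:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def mv(w, v):
    """Matrix-vector product over plain lists (stand-in for aten.mm)."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

def gru_cell(x, h_prev, w_ih, w_hh, b_ih, b_hh):
    """One GRU timestep decomposed into matmul/add/sigmoid/tanh/mul/slice.

    w_ih and w_hh each have 3*H rows (r, z, n gates stacked), so the gate
    pre-activations take two matrix products per timestep instead of six.
    """
    h_size = len(h_prev)
    gi = [a + b for a, b in zip(mv(w_ih, x), b_ih)]        # input-side pre-activations
    gh = [a + b for a, b in zip(mv(w_hh, h_prev), b_hh)]   # hidden-side pre-activations
    r_t = [sigmoid(a + b) for a, b in zip(gi[0:h_size], gh[0:h_size])]
    z_t = [sigmoid(a + b) for a, b in zip(gi[h_size:2 * h_size], gh[h_size:2 * h_size])]
    # n gate: the hidden-side contribution is scaled by r before the tanh.
    n_t = [math.tanh(a + r * b)
           for a, r, b in zip(gi[2 * h_size:3 * h_size], r_t, gh[2 * h_size:3 * h_size])]
    h_t = [n + z * (h - n) for n, z, h in zip(n_t, z_t, h_prev)]
    return h_t
```

Note that unlike the LSTM, the input-side and hidden-side pre-activations cannot be summed up front, because the n gate multiplies only the hidden-side term by r_t.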
Summary:

Adds a decomposition pass that transforms aten.rnn_tanh.input and
aten.rnn_relu.input into elementary ops supported by TOSA.

RNN cell equation per timestep:
    h_t = activation(x_t @ W_ih.T + b_ih + h_{t-1} @ W_hh.T + b_hh)

where activation is tanh (rnn_tanh) or relu (rnn_relu).

Features:
- Multi-layer RNN support
- Bidirectional RNN support
- With/without bias
- batch_first support
- Both tanh and relu nonlinearities

Differential Revision: D92059152
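
The single RNN cell equation can likewise be sketched in plain Python (an illustration with list-based vectors, not the pass itself; `rnn_cell` and `mv` are invented names):

```python
import math

def mv(w, v):
    """Matrix-vector product over plain lists (stand-in for aten.mm)."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

def rnn_cell(x, h_prev, w_ih, w_hh, b_ih, b_hh, nonlinearity="tanh"):
    """One RNN timestep: h_t = act(x @ W_ih.T + b_ih + h_prev @ W_hh.T + b_hh),
    where act is tanh (rnn_tanh) or relu (rnn_relu)."""
    act = math.tanh if nonlinearity == "tanh" else lambda v: max(v, 0.0)
    pre = [a + b + bi + bh
           for a, b, bi, bh in zip(mv(w_ih, x), mv(w_hh, h_prev), b_ih, b_hh)]
    return [act(p) for p in pre]
```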
Summary:

Adds a decomposition pass that transforms aten.lstm.input into elementary
ops supported by TOSA (matmul, sigmoid, tanh, mul, add, slice, cat).

LSTM cell equations per timestep:
    i_t = sigmoid(x_t @ W_ii.T + b_ii + h_{t-1} @ W_hi.T + b_hi)
    f_t = sigmoid(x_t @ W_if.T + b_if + h_{t-1} @ W_hf.T + b_hf)
    g_t = tanh(x_t @ W_ig.T + b_ig + h_{t-1} @ W_hg.T + b_hg)
    o_t = sigmoid(x_t @ W_io.T + b_io + h_{t-1} @ W_ho.T + b_ho)
    c_t = f_t * c_{t-1} + i_t * g_t
    h_t = o_t * tanh(c_t)

Features:
- Multi-layer LSTM support
- Bidirectional LSTM support
- With/without bias
- batch_first support
- Batched gate computation (2 mm ops per timestep instead of 8)

Differential Revision: D92059277
Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm
