NXP backend: added support for aten.conv_transpose1d and refactored convolution_converter #19004

Open

novak-vaclav wants to merge 1 commit into pytorch:main from nxp-upstream:feature/EIEX-681-add-transposed-conv-1d-support

Conversation

novak-vaclav (Contributor) commented Apr 20, 2026

Summary

Added support for aten.conv_transpose1d by moving the 1D-convolution handling out of convolution_converter into a new convert_1d_conv_to_2d ATen pass and extending it to cover transposed convolutions.
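For readers unfamiliar with the rewrite, the standard trick is to lift the 1D op into 2D by adding a dummy spatial dimension. A minimal sketch in plain tensor form (the actual pass works on the FX graph, inserting aten unsqueeze/conv/squeeze nodes; the function name and the choice of unsqueeze dimension here are illustrative):

```python
import torch
import torch.nn.functional as F

def conv_transpose1d_via_2d(x, w, bias=None, stride=1, padding=0,
                            output_padding=0, groups=1, dilation=1):
    # Lift to 2D: add a dummy height dimension of size 1.
    x2 = x.unsqueeze(-2)   # (N, C_in, L)            -> (N, C_in, 1, L)
    w2 = w.unsqueeze(-2)   # (C_in, C_out/groups, K) -> (C_in, C_out/groups, 1, K)
    y2 = F.conv_transpose2d(
        x2, w2, bias,
        stride=(1, stride), padding=(0, padding),
        output_padding=(0, output_padding), groups=groups, dilation=(1, dilation),
    )
    return y2.squeeze(-2)  # (N, C_out, 1, L') -> (N, C_out, L')

x, w = torch.randn(2, 4, 16), torch.randn(4, 8, 3)
assert torch.allclose(conv_transpose1d_via_2d(x, w),
                      F.conv_transpose1d(x, w), atol=1e-6)
```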

Test plan

Tests can be run manually with `pytest -c /dev/null backends/nxp/tests/`.

cc @robert-kalmar @JakeStevens @digantdesai @MartinPavella

Copilot AI review requested due to automatic review settings April 20, 2026 16:34
pytorch-bot (Bot) commented Apr 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19004

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures, 1 Cancelled Job, 2 Unrelated Failures

As of commit 73e97a6 with merge base 063f9c9:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla (Bot) added the "CLA Signed" label Apr 20, 2026
novak-vaclav (Contributor, Author) commented:

@pytorchbot label "release notes: nxp"

pytorch-bot (Bot) added the "release notes: nxp" label Apr 20, 2026
novak-vaclav (Contributor, Author) commented:

@pytorchbot label "module: nxp"

pytorch-bot (Bot) added the "module: nxp" label Apr 20, 2026
@novak-vaclav novak-vaclav marked this pull request as draft April 20, 2026 16:37
Copilot AI left a comment

Pull request overview

Adds NXP backend support for aten.conv_transpose1d by moving 1D-to-2D convolution lowering out of the IR converter and into a dedicated ATen graph rewrite pass, plus adjusts quantization handling for grouped transposed convolutions.

Changes:

  • Introduce ConvertConv1dToConv2dPass to rewrite aten.conv1d / aten.conv_transpose1d into 2D equivalents via unsqueeze → conv2d (or conv_transpose2d) → squeeze.
  • Remove 1D-convolution handling from the TFLite convolution_converter and enable the new pass in the default Neutron ATen pass pipeline.
  • Update quantizer patterns/utilities to correctly derive bias qparams for grouped conv_transpose2d and fix per-channel axis handling; add comprehensive tests for the new pass.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| backends/nxp/aten_passes/convert_1d_conv_to_2d.py | New ATen pass converting 1D conv/transposed conv to 2D form with shape/meta propagation. |
| backends/nxp/aten_passes/neutron_aten_pass_manager.py | Registers the new pass in the default Neutron ATen pass sequence. |
| backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py | Removes the 1D convolution conversion logic; now expects only the 2D weight rank. |
| backends/nxp/quantizer/utils.py | Adds a helper to “pad”/repeat weight scales when deriving bias qparams for grouped transposed conv. |
| backends/nxp/quantizer/patterns.py | Drops 1D conv patterns; updates ConvTranspose2d quantization (bias qparams + correct per-channel axis). |
| backends/nxp/quantizer/neutron_quantizer.py | Removes the Conv1dPattern registration (since 1D conv is rewritten earlier). |
| backends/nxp/tests/test_convert_1d_conv_to_2d.py | New test suite covering the conv1d + conv_transpose1d rewrite and full-pipeline delegation. |
| backends/nxp/tests/models.py | Updates the Conv1d test module API and adds ConvTranspose1d + runtime-weight conv1d models for testing. |
| backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py | Removes prior conv1d conversion tests (superseded by the new pass tests). |


@MartinPavella MartinPavella self-requested a review April 21, 2026 06:50
novak-vaclav force-pushed the feature/EIEX-681-add-transposed-conv-1d-support branch from 8568135 to 73e97a6 on April 29, 2026 07:58
@novak-vaclav novak-vaclav marked this pull request as ready for review April 29, 2026 07:59
Copilot AI review requested due to automatic review settings April 29, 2026 07:59
novak-vaclav (Contributor, Author) commented:

Refactored the whole solution. Handled cases where batch_norm follows conv1d, since inserting the squeeze between them prevented fusing the batch_norm into the conv.
I also thought that converting conv1d to conv2d with a bmm after it would break the bmm, but that should not happen.
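For context on the fusion point: if the pass emitted unsqueeze → conv2d → squeeze and a batch_norm followed, the squeeze would sit between the conv and the batch_norm, so conv+BN fusion no longer sees adjacent nodes. Keeping the batch_norm in the 2D form and squeezing afterwards avoids that. A small numeric sanity check of the reordering (eval mode, shapes illustrative; not the PR's code):

```python
import torch

x = torch.randn(2, 4, 16)                  # (N, C, L)
conv = torch.nn.Conv1d(4, 8, 3).eval()
bn = torch.nn.BatchNorm1d(8).eval()
y_ref = bn(conv(x))

# unsqueeze -> conv2d -> batch_norm (still in 2D form) -> squeeze
y2 = torch.nn.functional.conv2d(
    x.unsqueeze(-2), conv.weight.unsqueeze(-2), conv.bias
)
y2 = torch.nn.functional.batch_norm(
    y2, bn.running_mean, bn.running_var, bn.weight, bn.bias, training=False
)
assert torch.allclose(y2.squeeze(-2), y_ref, atol=1e-6)
```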

novak-vaclav (Contributor, Author) commented:

Please review @MartinPavella and @StrycekSimon. Thank you 😊

Copilot AI left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.



MartinPavella (Collaborator) left a comment

Great work!
I only have cosmetic suggestions and questions. Perhaps for future PRs you could split the implementation into multiple commits. Ultimately it would still get squashed to 1, but it would make the review easier as the commit messages would explain the motivation for some changes. Here I'm really not sure why some changes were made.

Also Copilot had some good comments, so please address them too if you see fit.

│ (1D/transposed 1D) │ ────────────────► │ (2D/transposed 2D) │
└────────────┬───────────┘ with └───────────┬───────────┘
│ │
[N, C2, 1, H] [N, C2, 1, H]
Nit: The left side should probably be [N, C2, H].


return node.meta["val"].dtype if hasattr(node, "meta") else node.dtype

def _convert_w_node_to_static_attr(self, node: Node):
I think the method name should mention that the node is reshaped, as if I understand correctly that is the main purpose of the method. Converting to static attribute seems like just the means to get the reshaped static parameter node.


with FakeTensorMode() as mode:
fake_node_args = self._create_fake_tensor_for_node_args(node_args, mode)
output = bn_target(*fake_node_args, *scalar_args)
Nice 👍🏻
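For readers unfamiliar with the trick being praised here: FakeTensorMode runs ops on metadata-only tensors, so a graph pass can derive the output shape/dtype for a newly created node without any real computation. A minimal sketch:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Inside the mode, factory calls produce FakeTensors (shape/dtype only).
    x = torch.empty(1, 8, 1, 32)
    w = torch.empty(16, 8, 1, 3)
    # Dispatch performs shape inference without allocating real data.
    out = torch.ops.aten.conv2d.default(x, w)
    print(out.shape, out.dtype)   # torch.Size([1, 16, 1, 30]) torch.float32
```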

from torch.fx.passes.infra.pass_base import PassBase, PassResult


Conv1dArgs = tuple[Node, Node, (Node | None), list[int], list[int], list[int], int]
Nit: Comments (or type aliases) explaining which parameter is which would be nice.
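For reference, the alias mirrors the positional schema of aten.conv1d (input, weight, bias?, stride, padding, dilation, groups), so per-field comments could look like this (a suggestion, not the PR's code):

```python
from torch.fx import Node

# (input, weight, bias, stride, padding, dilation, groups) — mirrors the
# positional schema of torch.ops.aten.conv1d.default.
Conv1dArgs = tuple[
    Node,          # input activation
    Node,          # weight
    Node | None,   # bias (optional)
    list[int],     # stride
    list[int],     # padding
    list[int],     # dilation
    int,           # groups
]
```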

with self.graph_module.graph.inserting_after(bn_2d_node):
squeeze_target = torch.ops.aten.squeeze.dim

out_sq_args = (bn_2d_node, 2)
Nit: When you created the unsqueeze earlier (and also the other squeeze in the else branch), you used index -2. Using a positive (or negative) index consistently for all cases would be clearer; as written, there is a bit of ambiguity.
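For a rank-4 tensor the two spellings address the same axis, so behavior is unaffected; the nit is purely about readability:

```python
import torch

t = torch.empty(2, 3, 1, 5)   # rank 4: dim -2 resolves to dim 2
assert t.squeeze(2).shape == t.squeeze(-2).shape == (2, 3, 5)
```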

def get_padded_bias_qparams(
obs_or_fqs: List[ObserverOrFakeQuantize],
out_channels: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
The function name doesn't explain what the function does. So I think some docstring here explaining it is specifically designed for transpose conv would be useful. Or perhaps the function name could mention the transpose conv?
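Background on why such a helper is needed (my reading of the diff, so treat the shapes as assumptions): with per-channel weight quantization, the bias scale is conventionally derived as input_scale * weight_scale per output channel. ConvTranspose2d weights are laid out (in_channels, out_channels/groups, kH, kW), so quantizing along axis 1 yields only out_channels/groups scales, which must be tiled to out_channels first. A sketch of that tiling (function name hypothetical, not the PR's helper):

```python
import torch

def tile_weight_scales(weight_scales: torch.Tensor, out_channels: int) -> torch.Tensor:
    # Per-channel quantization along axis 1 of a grouped conv_transpose2d
    # weight produces out_channels // groups scales; repeat them once per
    # group so bias qparams cover every output channel.
    groups = out_channels // weight_scales.numel()
    return weight_scales.repeat(groups)

# e.g. groups=2, out_channels=8: 4 scales tiled to 8
scales = torch.tensor([0.1, 0.2, 0.3, 0.4])
print(tile_weight_scales(scales, 8))
# tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.1000, 0.2000, 0.3000, 0.4000])
```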

@@ -105,31 +96,6 @@ def test_convert_bmm__unsupported_shape(input_shape_x1, input_shape_x2, use_qat)
assert graph_contains_any_of_ops(delegated_ep.graph, [Bmm])


def test_convert_bmm__unsupported_dim_order(mocker, use_qat):
Why was this test removed?

@@ -112,7 +112,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool):
module, tuple(input_shape)
).exported_program()

assert len(edge_program.graph.nodes) == 15
assert len(edge_program.graph.nodes) == 21
Why did the number of nodes change?
Is our backend still performing as expected?

@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("dilation", [2, 1])
@pytest.mark.parametrize("kernel_size", [(1,), (3,)])
def test_conv1d_quant_conversion(bias, stride, dilation, kernel_size, mocker, use_qat):
Why were these tests removed?

example_input = (np.random.random(input_shape).astype(np.float32) * 50).astype(
np.int8
)
convert_run_compare(
Ideally for a full pipeline test you should run the whole delegated model using nxp_executor_runner (the other convert_run_compare a.k.a. the lower_run_compare since #19112).
It's not a big deal here and no change is required, but for future PRs, please only use lower_run_compare. Comparing NeutronIR inference output with edge is now a thing of the past :).

StrycekSimon (Collaborator) left a comment

Please handle 2D inputs, or at least mark them as unsupported and create an issue.


@staticmethod
def _is_batch_norm(node: Node) -> bool:
return node.target == torch.ops.aten.batch_norm.default
Nit: that function is already implemented in graph_utils.py.

@pytest.mark.parametrize(
"input_shape, kernel_size, stride, padding, dilation, groups, bias",
[
pytest.param((3, 7, 23), 3, 1, 0, 1, 1, True, id="All default."),
Try adding an input of shape (3, 7), which is a valid value for Conv1d, and it will fail. BatchNorm always works on dim=1. By unsqueezing a tensor of shape (N, L), you create an (N, 1, L) one, causing the batch norm to operate not on the L dimension (as before) but on the newly added dimension of size 1. This crashes because of a shape mismatch. It should be either handled differently or marked as unsupported.
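A minimal repro of the described failure mode (shapes illustrative):

```python
import torch

bn = torch.nn.BatchNorm1d(7)
x = torch.randn(3, 7)      # 2D input: BatchNorm1d finds its 7 channels at dim 1
bn(x)                      # OK

x2 = x.unsqueeze(-2)       # (3, 1, 7): the unsqueeze moves the 7 to dim 2
bn(x2)                     # RuntimeError: channel-count mismatch, BN now sees
                           #   a channel dim of size 1 instead of 7
```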


return fake_node_args

def _create_batch_norm_2d_node(self, *bn_args):
Nit: You most probably don't even need to create a new node, as batch norm is handled by a single ATen op for all input ranks.
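Indeed, torch.ops.aten.batch_norm.default is rank-agnostic (it always normalizes dim 1), so the existing node's target should work for the 2D form as well; a quick check:

```python
import torch

bn_args = (torch.randn(7), torch.randn(7),   # weight, bias
           torch.zeros(7), torch.ones(7),    # running_mean, running_var
           False, 0.1, 1e-5, False)          # training, momentum, eps, cudnn_enabled

# The same ATen op handles (N, C, L) and (N, C, H, W) inputs alike.
y3 = torch.ops.aten.batch_norm.default(torch.randn(2, 7, 9), *bn_args)
y4 = torch.ops.aten.batch_norm.default(torch.randn(2, 7, 1, 9), *bn_args)
print(y3.shape, y4.shape)   # torch.Size([2, 7, 9]) torch.Size([2, 7, 1, 9])
```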

np.random.seed(23)


AtenConv1d = torch.ops.aten.conv1d.default
Nice :)


Labels

  • CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
  • module: nxp (issues related to NXP Neutron NPU delegation and code under backends/nxp/)
  • release notes: nxp (changes to the NXP Neutron backend delegate)

Projects

None yet


4 participants