Commit 9c2246e

Arm backend: Stop removing kwargs in quantizer (#16931)
torchao now allows non-tensor kwargs when quantizing. I am currently not aware of any tensor kwargs, so it is probably better to crash and handle them if we ever run into them than to silently leave them unquantized. Add a test to make sure kwargs actually do survive quantization.
1 parent 3ff3082 commit 9c2246e

2 files changed

Lines changed: 53 additions & 13 deletions
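As a hedged illustration of the "crash instead of silently dropping" idea from the commit message (not code from this commit; the helper name check_no_tensor_kwargs and the choice of NotImplementedError are assumptions), a guard like this could sit in the annotation path:

import torch
from torch.fx import Node

def check_no_tensor_kwargs(node: Node) -> None:
    # Hypothetical guard: the annotator cannot quantize tensor kwargs, so fail
    # loudly if one ever appears rather than silently skipping it.
    for name, value in node.kwargs.items():
        if isinstance(value, torch.Tensor):
            raise NotImplementedError(
                f"Tensor kwarg '{name}' on {node.target} is not supported by the quantizer."
            )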

backends/arm/quantizer/quantization_annotator.py

Lines changed: 0 additions & 13 deletions
@@ -840,16 +840,3 @@ def annotate_graph( # type: ignore[return]
         _annotate_output(node, quant_properties.quant_output)

         mark_node_as_annotated(node)  # type: ignore[attr-defined]
-
-        # Quantization does not allow kwargs for some reason.
-        # Remove from ops we know have and where we know it does not break anything.
-        if node.target in [
-            torch.ops.aten.full_like.default,
-            torch.ops.aten.full.default,
-            torch.ops.aten.full,
-            torch.ops.aten.fill_.Scalar,
-            torch.ops.aten.scalar_tensor.default,
-            torch.ops.aten.zeros.default,
-            torch.ops.aten.ones.default,
-        ]:
-            node.kwargs = {}
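The deleted block silently wiped kwargs such as memory_format on the listed ops. A minimal sketch, assuming a plain torch.export (the module M and the input shape are illustrative, not from this commit), of why that lost information: the kwarg is still attached to the aten.full_like node at export time.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + torch.full_like(x, 3.0, memory_format=torch.channels_last)

ep = torch.export.export(M(), (torch.randn(1, 3, 4, 4),))
for node in ep.graph.nodes:
    if node.target == torch.ops.aten.full_like.default:
        # Expect memory_format among the kwargs; the removed code reset this to {}.
        print(node.kwargs)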
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import pytest
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.test.harness.stages.stage import StageType
+
+input_t1 = Tuple[torch.Tensor, int]
+
+exir_op = "executorch_exir_dialects_edge__ops_aten_full_default"
+
+
+class FullLike(torch.nn.Module):
+    """Since full_like is replaced with full, we only need to test on reference model, not FVP."""
+
+    test_parameters = {
+        "full_like_int_val": lambda: (torch.randn(2, 2, 2, 2) * 50, 3),
+        "full_like_float_val": lambda: (torch.randn(2, 4, 5, 2) * 50, 3.2),
+    }
+
+    def forward(self, input_tensor: torch.Tensor, value):
+        # Our backend can't handle tensors without users, which input_tensor doesn't have
+        # when the full_like is converted to a full. Therefore involve it in the output.
+        return input_tensor + torch.full_like(
+            input_tensor, value, dtype=torch.float32, memory_format=torch.channels_last
+        )
+
+
+@common.parametrize("test_data", FullLike.test_parameters)
+def test_preserves_kwargs_tosa_INT(test_data):
+    pipeline = TosaPipelineINT[input_t1](
+        FullLike(),
+        test_data(),
+        aten_op=[],
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+    # Test that kwarg memory_format survived quantization.
+    graph_module = pipeline.tester.get_artifact(StageType.EXPORT).graph_module
+    nodes = graph_module.graph.nodes
+    for n in nodes:
+        if n.target == torch.ops.aten.full_like.default:
+            assert n.meta["val"].dim_order() == (0, 2, 3, 1)
+            break
+    else:
+        pytest.fail("Did not find torch.ops.aten.full_like")
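The final assertion relies on Tensor.dim_order(): a contiguous 4-D tensor reports (0, 1, 2, 3), while a channels_last one reports (0, 2, 3, 1), so the check only passes if the memory_format kwarg survived annotation. A quick standalone illustration:

import torch

t = torch.empty(2, 3, 4, 5)
print(t.dim_order())                                        # (0, 1, 2, 3)
print(t.to(memory_format=torch.channels_last).dim_order())  # (0, 2, 3, 1)

The for/else in the test makes the failure explicit: pytest.fail only runs if no full_like node was found, rather than letting the loop fall through silently.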
