Skip to content

Commit 77a9a41

Browse files
Add a16w8 per-op test for bmm (#19599)
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.bmm` on Ethos-U55 and Ethos-U85.

## Changes

- Add `a16w8_bmm_test_parameters` dict with 5 test configurations covering same-shape, different-shape, rectangular, batch-10, and negative-value tensors
- Add `test_bmm_a16w8_u55_INT` using `OpNotSupportedPipeline` with `u55_subset=True, quantize=True, tosa_extensions=["int16"]` and the symmetric a16w8 quantization config, verifying that bmm with INT16 inputs is rejected (not delegated) on U55
- Add `test_bmm_a16w8_u85_INT` using `EthosU85PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True, qtol=128, epsilon=2**-16`
- Remove unused `aten_op_mm` and `exir_op_mm` variables
- Register `ops/test_bmm.py` in `fbcode/` and `xplat/` `targets.bzl`

bypass-pytorch-oss-checks

Differential Revision: D104532363
1 parent 1371cae commit 77a9a41

2 files changed

Lines changed: 49 additions & 4 deletions

File tree

backends/arm/test/ops/test_bmm.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -12,9 +12,11 @@
1212

1313
from executorch.backends.arm.test import common
1414

15+
from executorch.backends.arm.quantizer import get_symmetric_a16w8_quantization_config
1516
from executorch.backends.arm.test.tester.test_pipeline import (
1617
EthosU55PipelineINT,
1718
EthosU85PipelineINT,
19+
OpNotSupportedPipeline,
1820
TosaPipelineFP,
1921
TosaPipelineINT,
2022
VgfPipeline,
@@ -23,9 +25,6 @@
2325
aten_op_bmm = "torch.ops.aten.bmm.default"
2426
exir_op_bmm = "executorch_exir_dialects_edge__ops_aten_bmm_default"
2527

26-
aten_op_mm = "torch.ops.aten.matmul.default"
27-
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
28-
2928
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
3029

3130

@@ -191,3 +190,48 @@ def test_bmm_vgf_quant_single_input(test_data: input_t1):
191190
quantize=True,
192191
)
193192
pipeline.run()
193+
194+
195+
def _rand_bmm_operands(lhs_shape, rhs_shape):
    """Return a zero-argument factory for a pair of uniform-random tensors.

    The tensors are only created when the returned callable is invoked, and
    the left operand is always drawn before the right one.
    """
    return lambda: (torch.rand(*lhs_shape), torch.rand(*rhs_shape))


# Lazily evaluated operand pairs for the a16w8 bmm tests, keyed by test id.
# Every entry is a callable returning (lhs, rhs) with shapes (B, N, K) and
# (B, K, M), so the pair can be fed directly to a batched matmul.
a16w8_bmm_test_parameters = {
    "rand_same": _rand_bmm_operands((2, 1, 1), (2, 1, 1)),
    "rand_diff": _rand_bmm_operands((5, 3, 5), (5, 5, 2)),
    "rand_rect": _rand_bmm_operands((1, 55, 3), (1, 3, 44)),
    "rand_batch10": _rand_bmm_operands((10, 1, 10), (10, 10, 5)),
    # Normally distributed operands: the lhs is scaled by -10 and the rhs is
    # scaled by 5 and offset by 5, so this case is not limited to [0, 1).
    "rand_neg": lambda: (
        -10 * torch.randn(2, 32, 64),
        5 + 5 * torch.randn(2, 64, 32),
    ),
}
205+
206+
207+
@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone300
def test_bmm_a16w8_u55_INT(test_data: input_t1):
    """Verify that a16w8-quantized bmm is rejected for U55.

    U55 does not support bmm with INT16 inputs, so the pipeline is expected
    to leave exactly one bmm edge op undelegated and create no delegates.

    Args:
        test_data: Zero-argument factory returning the two bmm operands.
    """
    module = BMM()
    inputs = test_data()
    # The bmm op must remain in the graph as a single non-delegated node.
    expected_non_delegated = {exir_op_bmm: 1}

    pipeline = OpNotSupportedPipeline[input_t1](
        module,
        inputs,
        non_delegated_ops=expected_non_delegated,
        n_expected_delegates=0,
        u55_subset=True,
        quantize=True,
        tosa_extensions=["int16"],
    )

    # Replace the quantizer's global config with the symmetric a16w8 one
    # before running the pipeline.
    a16w8_config = get_symmetric_a16w8_quantization_config()
    pipeline.quantizer.set_global(a16w8_config)
    pipeline.run()
222+
223+
224+
@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone320
def test_bmm_a16w8_u85_INT(test_data: input_t1):
    """Run bmm through the Ethos-U85 INT pipeline with a16w8 quantization.

    Args:
        test_data: Zero-argument factory returning the two bmm operands.
    """
    module = BMM()
    inputs = test_data()
    # Quantization mode and comparison tolerances forwarded to the pipeline.
    a16w8_options = {
        "a16w8_quantization": True,
        "symmetric_io_quantization": True,
        "qtol": 128,
        "epsilon": 2**-16,
    }

    pipeline = EthosU85PipelineINT[input_t1](
        module,
        inputs,
        aten_op_bmm,
        exir_op_bmm,
        **a16w8_options,
    )
    pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def define_arm_tests():
4242
"ops/test_var.py",
4343
"ops/test_conv1d.py",
4444
"ops/test_gelu.py",
45+
"ops/test_bmm.py",
4546
]
4647

4748
# Quantization

0 commit comments

Comments
 (0)