Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions backends/arm/test/ops/test_bmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -12,9 +12,11 @@

from executorch.backends.arm.test import common

from executorch.backends.arm.quantizer import get_symmetric_a16w8_quantization_config
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
OpNotSupportedPipeline,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
Expand All @@ -23,9 +25,6 @@
aten_op_bmm = "torch.ops.aten.bmm.default"
exir_op_bmm = "executorch_exir_dialects_edge__ops_aten_bmm_default"

aten_op_mm = "torch.ops.aten.matmul.default"
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"

input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x


Expand Down Expand Up @@ -191,3 +190,48 @@
quantize=True,
)
pipeline.run()


# Input factories for the a16w8 (int16 activation / int8 weight) bmm tests.
# Each entry lazily builds a (lhs, rhs) pair with compatible batched-matmul shapes.
a16w8_bmm_test_parameters = {
    "rand_same": lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
    "rand_diff": lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
    "rand_rect": lambda: (torch.rand(1, 55, 3), torch.rand(1, 3, 44)),
    "rand_batch10": lambda: (torch.rand(10, 1, 10), torch.rand(10, 10, 5)),
    # Mixed-sign / shifted inputs to exercise asymmetric value ranges.
    "rand_neg": lambda: (
        torch.randn(2, 32, 64) * -10,
        torch.randn(2, 64, 32) * 5 + 5,
    ),
}


@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone300
def test_bmm_a16w8_u55_INT(test_data: input_t1):
    """bmm with int16 activations is not supported on U55: expect rejection.

    The op must stay undelegated (exactly one non-delegated bmm, zero
    delegates) when the int16 TOSA extension is requested.
    """
    reject_pipeline = OpNotSupportedPipeline[input_t1](
        BMM(),
        test_data(),
        non_delegated_ops={exir_op_bmm: 1},
        n_expected_delegates=0,
        u55_subset=True,
        quantize=True,
        tosa_extensions=["int16"],
    )
    # a16w8: int16 activations with int8 weights.
    reject_pipeline.quantizer.set_global(get_symmetric_a16w8_quantization_config())
    reject_pipeline.run()


@common.parametrize("test_data", a16w8_bmm_test_parameters)
@common.XfailIfNoCorstone320
def test_bmm_a16w8_u85_INT(test_data: input_t1):
    """Run bmm on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU85PipelineINT[input_t1](
        BMM(),
        test_data(),
        aten_op_bmm,
        exir_op_bmm,
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()
73 changes: 73 additions & 0 deletions backends/arm/test/ops/test_conv1d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -399,3 +399,76 @@
get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
)
pipeline.run()


# a16w8 (int16 activation, int8 weight) quantization test configurations.
def _build_a16w8_conv1d_params():
    """Expand each base Conv1d config into per-channel-quant True/False cases.

    Returns a dict mapping "<name>,per_channel_quant=<bool>" to a lazy
    factory producing (model, per_channel_quantization).
    """
    base_configs = {
        "k1_1x2x128_st1": lambda: Conv1d(
            in_channels=2, out_channels=1, kernel_size=1,
            stride=1, padding=0, length=128, batches=1,
        ),
        "k3_1x3x64_st1_pd1": lambda: Conv1d(
            in_channels=3, out_channels=4, kernel_size=3,
            stride=1, padding=1, length=64, batches=1,
        ),
        "k5_1x2x64_st1_pd2": lambda: Conv1d(
            in_channels=2, out_channels=3, kernel_size=5,
            stride=1, padding=2, length=64, batches=1,
        ),
        "k3_1x3x32_st2_pd1": lambda: Conv1d(
            in_channels=3, out_channels=4, kernel_size=3,
            stride=2, padding=1, length=32, batches=1,
        ),
        "k3_1x3x32_st1_dl2": lambda: Conv1d(
            in_channels=3, out_channels=4, kernel_size=3,
            stride=1, padding=0, dilation=2, length=32, batches=1,
        ),
        "k3_1x4x32_st1_pd1_depthwise": lambda: Conv1d(
            in_channels=4, out_channels=4, kernel_size=3,
            stride=1, padding=1, groups=4, length=32, batches=1,
        ),
        "k3_1x3x64_st1_pd1_nobias": lambda: Conv1d(
            in_channels=3, out_channels=4, kernel_size=3,
            stride=1, padding=1, bias=False, length=64, batches=1,
        ),
    }
    cases = {}
    for name, make_model in base_configs.items():
        for per_channel in (True, False):
            # Bind loop variables via defaults so each lambda is independent.
            cases[f"{name},per_channel_quant={per_channel}"] = (
                lambda make_model=make_model, per_channel=per_channel: (
                    make_model(),
                    per_channel,
                )
            )
    return cases


a16w8_conv1d_test_parameters = _build_a16w8_conv1d_params()


@common.parametrize("test_data", a16w8_conv1d_test_parameters)
@common.XfailIfNoCorstone300
def test_conv1d_a16w8_u55_INT(test_data):
    """Run Conv1d on U55 with a16w8 (int16 activation / int8 weight) quantization."""
    module, per_channel = test_data()
    pipeline = EthosU55PipelineINT[input_t](
        module,
        module.get_inputs(),
        aten_op,
        exir_op,
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        per_channel_quantization=per_channel,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    )
    pipeline.run()


@common.parametrize("test_data", a16w8_conv1d_test_parameters)
@common.XfailIfNoCorstone320
def test_conv1d_a16w8_u85_INT(test_data):
    """Run Conv1d on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    module, per_channel = test_data()
    pipeline = EthosU85PipelineINT[input_t](
        module,
        module.get_inputs(),
        aten_op,
        exir_op,
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        per_channel_quantization=per_channel,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    )
    pipeline.run()
39 changes: 39 additions & 0 deletions backends/arm/test/ops/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,42 @@ def test_exp_vgf_quant(test_data: Tuple):
quantize=True,
)
pipeline.run()


# Input factories for the a16w8 exp tests, covering ranks 1-3 and mixed signs.
a16w8_exp_test_parameters = {
    "rank1_rand": lambda: torch.rand(10),  # values in [0, 1)
    "rank2_rand": lambda: torch.rand(8, 8) - 0.5,  # values in [-0.5, 0.5)
    "rank3_rand": lambda: torch.rand(1, 4, 4) * 2 - 1,  # values in [-1, 1)
}


@common.parametrize("test_data", a16w8_exp_test_parameters)
@common.XfailIfNoCorstone300
def test_exp_a16w8_u55_INT(test_data: Tuple):
    """Run exp on U55 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU55PipelineINT[input_t1](
        Exp(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        a16w8_quantization=True,
        symmetric_io_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()


@common.parametrize("test_data", a16w8_exp_test_parameters)
@common.XfailIfNoCorstone320
def test_exp_a16w8_u85_INT(test_data: Tuple):
    """Run exp on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU85PipelineINT[input_t1](
        Exp(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        a16w8_quantization=True,
        symmetric_io_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()
39 changes: 39 additions & 0 deletions backends/arm/test/ops/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,42 @@ def test_gelu_vgf_quant(test_data: input_t1):
quantize=True,
)
pipeline.run()


# Input factories for the a16w8 gelu tests, covering ranks 1-3 and shifted ranges.
a16w8_gelu_test_parameters = {
    "rank1_rand": lambda: torch.rand(10),  # values in [0, 1)
    "rank2_rand": lambda: torch.rand(8, 8) - 0.5,  # values in [-0.5, 0.5)
    "rank3_randn": lambda: torch.randn(1, 4, 4) + 2,  # normal, shifted to ~2
}


@common.parametrize("test_data", a16w8_gelu_test_parameters)
@common.XfailIfNoCorstone300
def test_gelu_a16w8_u55_INT(test_data: input_t1):
    """Run gelu on U55 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU55PipelineINT[input_t1](
        Gelu(),
        (test_data(),),
        Gelu.aten_op,
        Gelu.exir_op,
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()


@common.parametrize("test_data", a16w8_gelu_test_parameters)
@common.XfailIfNoCorstone320
def test_gelu_a16w8_u85_INT(test_data: input_t1):
    """Run gelu on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU85PipelineINT[input_t1](
        Gelu(),
        (test_data(),),
        Gelu.aten_op,
        Gelu.exir_op,
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()
47 changes: 47 additions & 0 deletions backends/arm/test/ops/test_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,50 @@ def test_mean_tosa_INT(test_data):
symmetric_io_quantization=True,
)
pipeline.run()


# Input factories for the a16w8 mean tests.
# Each entry lazily builds (tensor, dim, keepdim); dim may be an int, a tuple
# of (possibly negative) dims, or None for a full reduction.
a16w8_mean_test_parameters = {
    "rank_1_keepdim": lambda: (torch.rand(7), 0, True),
    "rank_2_keepdim": lambda: (torch.rand(7, 3), (0, 1), True),
    "rank_3_keepdim": lambda: (torch.rand(5, 7, 3), (0, 1, 2), True),
    "rand_1_keepdim": lambda: (torch.rand(1, 5, 7, 3), 1, True),
    "rand_23_keepdim": lambda: (torch.rand(1, 5, 7, 3), (2, 3), True),
    "rand_0123_keepdim": lambda: (torch.rand(1, 5, 7, 3), (0, 1, 2, 3), True),
    "rand_none_keepdim": lambda: (torch.rand(1, 5, 7, 3), None, True),
    "rank_1": lambda: (torch.rand(7), -1, False),
    "rank_2": lambda: (torch.rand(5, 7), (-2, -1), False),
    "rand_3": lambda: (torch.rand(1, 5, 7, 3), -1, False),
    "rand_123": lambda: (torch.rand(1, 5, 7, 3), (-3, -2, -1), False),
}


@common.parametrize("test_data", a16w8_mean_test_parameters)
@common.XfailIfNoCorstone300
def test_mean_dim_a16w8_u55_INT(test_data):
    """Run mean on U55 with a16w8 (int16 activation / int8 weight) quantization."""
    data, dim, keepdim = test_data()
    EthosU55PipelineINT[input_t](
        MeanDim(dim, keepdim),
        (data,),
        [],
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()


@common.parametrize("test_data", a16w8_mean_test_parameters)
@common.XfailIfNoCorstone320
def test_mean_dim_a16w8_u85_INT(test_data):
    """Run mean on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    data, dim, keepdim = test_data()
    EthosU85PipelineINT[input_t](
        MeanDim(dim, keepdim),
        (data,),
        [],
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()
39 changes: 39 additions & 0 deletions backends/arm/test/ops/test_reciprocal.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,42 @@ def test_reciprocal_vgf_quant(test_data: torch.Tensor):
quantize=True,
)
pipeline.run()


# Input factories for the a16w8 reciprocal tests.
# All values are shifted into [0.5, 1.5) to stay safely away from zero.
a16w8_reciprocal_test_parameters = {
    "rank1": lambda: torch.rand(10) + 0.5,
    "rank2": lambda: torch.rand(5, 10) + 0.5,
    "rank3": lambda: torch.rand(2, 5, 10) + 0.5,
}


@common.parametrize("test_data", a16w8_reciprocal_test_parameters)
@common.XfailIfNoCorstone300
def test_reciprocal_a16w8_u55_INT(test_data: torch.Tensor):
    """Run reciprocal on U55 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU55PipelineINT[input_t1](
        Reciprocal(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()


@common.parametrize("test_data", a16w8_reciprocal_test_parameters)
@common.XfailIfNoCorstone320
def test_reciprocal_a16w8_u85_INT(test_data: torch.Tensor):
    """Run reciprocal on U85 with a16w8 (int16 activation / int8 weight) quantization."""
    EthosU85PipelineINT[input_t1](
        Reciprocal(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        symmetric_io_quantization=True,
        a16w8_quantization=True,
        # Wider int16 tolerances: one LSB of a 16-bit quantized value.
        epsilon=2**-16,
        qtol=128,
    ).run()
Loading
Loading