Skip to content

Commit 60842b0

Browse files
Add a16w8 per-op test for gelu
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.gelu` on Ethos-U55 and Ethos-U85. ## Changes - Add `a16w8_gelu_test_parameters` dict with 3 test configurations covering rank-1, rank-2, and rank-3 tensors - Add `test_gelu_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True, qtol=128, epsilon=2**-16` - Add `test_gelu_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs - Register `ops/test_gelu.py` in `fbcode/` and `xplat/` `targets.bzl` bypass-pytorch-oss-checks Differential Revision: D104532359
1 parent 91ce246 commit 60842b0

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

backends/arm/test/ops/test_gelu.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Tuple
77

88
import torch
9+
910
from executorch.backends.arm.test import common
1011
from executorch.backends.arm.test.tester.test_pipeline import (
1112
EthosU55PipelineINT,
@@ -176,3 +177,42 @@ def test_gelu_vgf_quant(test_data: input_t1):
176177
quantize=True,
177178
)
178179
pipeline.run()
180+
181+
182+
a16w8_gelu_test_parameters = {
183+
"rank1_rand": lambda: torch.rand(10),
184+
"rank2_rand": lambda: torch.rand(8, 8) - 0.5,
185+
"rank3_randn": lambda: torch.randn(1, 4, 4) + 2,
186+
}
187+
188+
189+
@common.parametrize("test_data", a16w8_gelu_test_parameters)
190+
@common.XfailIfNoCorstone300
191+
def test_gelu_a16w8_u55_INT(test_data: input_t1):
192+
pipeline = EthosU55PipelineINT[input_t1](
193+
Gelu(),
194+
(test_data(),),
195+
Gelu.aten_op,
196+
Gelu.exir_op,
197+
a16w8_quantization=True,
198+
symmetric_io_quantization=True,
199+
qtol=128,
200+
epsilon=2**-16,
201+
)
202+
pipeline.run()
203+
204+
205+
@common.parametrize("test_data", a16w8_gelu_test_parameters)
206+
@common.XfailIfNoCorstone320
207+
def test_gelu_a16w8_u85_INT(test_data: input_t1):
208+
pipeline = EthosU85PipelineINT[input_t1](
209+
Gelu(),
210+
(test_data(),),
211+
Gelu.aten_op,
212+
Gelu.exir_op,
213+
a16w8_quantization=True,
214+
symmetric_io_quantization=True,
215+
qtol=128,
216+
epsilon=2**-16,
217+
)
218+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def define_arm_tests():
4141
"ops/test_mean_dim.py",
4242
"ops/test_var.py",
4343
"ops/test_conv1d.py",
44+
"ops/test_gelu.py",
4445
]
4546

4647
# Quantization

0 commit comments

Comments
 (0)