Skip to content

Commit d718c90

Browse files
Add a16w8 per-op test for gelu
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.gelu` on Ethos-U55 and Ethos-U85. ## Changes - Add `a16w8_gelu_test_parameters` dict with 3 test configurations covering rank-1, rank-2, and rank-3 tensors - Add `test_gelu_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True, qtol=128, epsilon=2**-16` - Add `test_gelu_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs - Register `ops/test_gelu.py` in `fbcode/` and `xplat/` `targets.bzl` bypass-pytorch-oss-checks Differential Revision: D104532359
1 parent 400371e commit d718c90

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

backends/arm/test/ops/test_gelu.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,42 @@ def test_gelu_vgf_quant(test_data: input_t1):
176176
quantize=True,
177177
)
178178
pipeline.run()
179+
180+
181+
a16w8_gelu_test_parameters = {
182+
"rank1_rand": lambda: torch.rand(10),
183+
"rank2_rand": lambda: torch.rand(8, 8) - 0.5,
184+
"rank3_randn": lambda: torch.randn(1, 4, 4) + 2,
185+
}
186+
187+
188+
@common.parametrize("test_data", a16w8_gelu_test_parameters)
189+
@common.XfailIfNoCorstone300
190+
def test_gelu_a16w8_u55_INT(test_data: input_t1):
191+
pipeline = EthosU55PipelineINT[input_t1](
192+
Gelu(),
193+
(test_data(),),
194+
Gelu.aten_op,
195+
Gelu.exir_op,
196+
a16w8_quantization=True,
197+
symmetric_io_quantization=True,
198+
qtol=128,
199+
epsilon=2**-16,
200+
)
201+
pipeline.run()
202+
203+
204+
@common.parametrize("test_data", a16w8_gelu_test_parameters)
205+
@common.XfailIfNoCorstone320
206+
def test_gelu_a16w8_u85_INT(test_data: input_t1):
207+
pipeline = EthosU85PipelineINT[input_t1](
208+
Gelu(),
209+
(test_data(),),
210+
Gelu.aten_op,
211+
Gelu.exir_op,
212+
a16w8_quantization=True,
213+
symmetric_io_quantization=True,
214+
qtol=128,
215+
epsilon=2**-16,
216+
)
217+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def define_arm_tests():
4141
"ops/test_mean_dim.py",
4242
"ops/test_var.py",
4343
"ops/test_conv1d.py",
44+
"ops/test_gelu.py",
4445
]
4546

4647
# Quantization

0 commit comments

Comments
 (0)