Skip to content

Commit 8fd5ef8

Browse files
Add a16w8 per-op test for split
Summary: Add int16 activation / int8 weight (a16w8) quantization tests for `aten.split` on Ethos-U55 and Ethos-U85. ## Changes - Add `a16w8_split_test_parameters` dict with 3 test configurations covering 1D, 2D, and 3D splits along different axes - Add `test_split_a16w8_u55_INT` using `EthosU55PipelineINT` with `a16w8_quantization=True, symmetric_io_quantization=True, qtol=128, epsilon=2**-16` - Add `test_split_a16w8_u85_INT` using `EthosU85PipelineINT` with same kwargs - Register `ops/test_split.py` in `fbcode/` and `xplat/` `targets.bzl` bypass-pytorch-oss-checks Differential Revision: D104533281
1 parent 8cba275 commit 8fd5ef8

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

backends/arm/test/ops/test_split.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,42 @@ def test_split_tensor_vgf_quant(test_data: Tuple):
310310
quantize=True,
311311
)
312312
pipeline.run()
313+
314+
315+
a16w8_split_test_parameters = {
316+
"a16w8_1d_split_2": lambda: (torch.rand(10), 2, 0),
317+
"a16w8_2d_split_4": lambda: (torch.rand(8, 4), 4, 0),
318+
"a16w8_3d_split_4": lambda: (torch.rand(4, 4, 8), 4, 2),
319+
}
320+
321+
322+
@common.parametrize("test_data", a16w8_split_test_parameters)
323+
@common.XfailIfNoCorstone300
324+
def test_split_a16w8_u55_INT(test_data: input_t1):
325+
pipeline = EthosU55PipelineINT[input_t1](
326+
Split(),
327+
test_data(),
328+
aten_ops=[],
329+
exir_ops=exir_op,
330+
a16w8_quantization=True,
331+
symmetric_io_quantization=True,
332+
qtol=1,
333+
epsilon=2**-16,
334+
)
335+
pipeline.run()
336+
337+
338+
@common.parametrize("test_data", a16w8_split_test_parameters)
339+
@common.XfailIfNoCorstone320
340+
def test_split_a16w8_u85_INT(test_data: input_t1):
341+
pipeline = EthosU85PipelineINT[input_t1](
342+
Split(),
343+
test_data(),
344+
aten_ops=[],
345+
exir_ops=exir_op,
346+
a16w8_quantization=True,
347+
symmetric_io_quantization=True,
348+
qtol=1,
349+
epsilon=2**-16,
350+
)
351+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def define_arm_tests():
4343
"ops/test_conv1d.py",
4444
"ops/test_gelu.py",
4545
"ops/test_bmm.py",
46+
"ops/test_split.py",
4647
]
4748

4849
# Quantization

0 commit comments

Comments
 (0)