Skip to content

Commit a8cfe69

Browse files
3l1facebook-github-bot
authored andcommitted
Add multi-reader tests for Add/Sub rescale fusion (#18758)
Summary: Add AddMultiReader and SubMultiReader test models (conv2(conv1(x)) +/- conv3(conv1(x))) where conv1's output Rescale has two readers. These exercise the multi-reader per-consumer fusion loop. Differential Revision: D99939008
1 parent 063f9c9 commit a8cfe69

2 files changed

Lines changed: 90 additions & 0 deletions

File tree

backends/arm/test/ops/test_add.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,52 @@ def test_add_dual_conv_u85_INT(test_data: input_t1):
379379
pipeline.run()
380380

381381

382+
class AddMultiReader(torch.nn.Module):
383+
"""Conv2(conv1(x)) + conv3(conv1(x)) — conv1's output Rescale has two
384+
readers.
385+
"""
386+
387+
def __init__(self):
388+
super().__init__()
389+
self.conv1 = torch.nn.Conv2d(3, 3, 1, bias=False)
390+
self.conv2 = torch.nn.Conv2d(3, 3, 1, bias=False)
391+
self.conv3 = torch.nn.Conv2d(3, 3, 1, bias=False)
392+
393+
def forward(self, x):
394+
y = self.conv1(x)
395+
return self.conv2(y) + self.conv3(y)
396+
397+
test_data = {
398+
"4d_randn": lambda: (torch.randn(1, 3, 4, 4),),
399+
}
400+
401+
402+
@common.parametrize("test_data", AddMultiReader.test_data)
403+
def test_add_multi_reader_tosa_INT(test_data: input_t1):
404+
pipeline = TosaPipelineINT[input_t1](
405+
AddMultiReader(), test_data(), aten_op, exir_op
406+
)
407+
pipeline.run()
408+
409+
410+
@common.parametrize("test_data", AddMultiReader.test_data)
411+
@common.XfailIfNoCorstone300
412+
def test_add_multi_reader_u55_INT(test_data: input_t1):
413+
pipeline = EthosU55PipelineINT[input_t1](
414+
AddMultiReader(), test_data(), aten_op, exir_op
415+
)
416+
pipeline.run()
417+
418+
419+
@common.parametrize("test_data", AddMultiReader.test_data)
420+
@common.XfailIfNoCorstone320
421+
def test_add_multi_reader_u85_INT(test_data: input_t1):
422+
pipeline = EthosU85PipelineINT[input_t1](
423+
AddMultiReader(), test_data(), aten_op, exir_op
424+
)
425+
pipeline.run()
426+
427+
382428
@common.parametrize("test_data", Add.test_data)
383429
def test_add_tensor_tosa_INT_16a8w(test_data: input_t1):
384430
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit

backends/arm/test/ops/test_sub.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,50 @@ def test_sub_dual_conv_u85_INT(test_data: input_t1):
394394
pipeline.run()
395395

396396

397+
class SubMultiReader(torch.nn.Module):
398+
"""conv2(conv1(x)) - conv3(conv1(x)) — conv1's output Rescale has two readers."""
399+
400+
def __init__(self):
401+
super().__init__()
402+
self.conv1 = torch.nn.Conv2d(3, 3, 1, bias=False)
403+
self.conv2 = torch.nn.Conv2d(3, 3, 1, bias=False)
404+
self.conv3 = torch.nn.Conv2d(3, 3, 1, bias=False)
405+
406+
def forward(self, x):
407+
y = self.conv1(x)
408+
return self.conv2(y) - self.conv3(y)
409+
410+
test_data = {
411+
"4d_randn": lambda: (torch.randn(1, 3, 4, 4),),
412+
}
413+
414+
415+
@common.parametrize("test_data", SubMultiReader.test_data)
416+
def test_sub_multi_reader_tosa_INT(test_data: input_t1):
417+
pipeline = TosaPipelineINT[input_t1](
418+
SubMultiReader(), test_data(), aten_op, exir_op
419+
)
420+
pipeline.run()
421+
422+
423+
@common.parametrize("test_data", SubMultiReader.test_data)
424+
@common.XfailIfNoCorstone300
425+
def test_sub_multi_reader_u55_INT(test_data: input_t1):
426+
pipeline = EthosU55PipelineINT[input_t1](
427+
SubMultiReader(), test_data(), aten_op, exir_op
428+
)
429+
pipeline.run()
430+
431+
432+
@common.parametrize("test_data", SubMultiReader.test_data)
433+
@common.XfailIfNoCorstone320
434+
def test_sub_multi_reader_u85_INT(test_data: input_t1):
435+
pipeline = EthosU85PipelineINT[input_t1](
436+
SubMultiReader(), test_data(), aten_op, exir_op
437+
)
438+
pipeline.run()
439+
440+
397441
@common.parametrize("test_data", sub_test_data)
398442
def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1):
399443
"""Test sub operation with 16A8W quantization (16-bit activations, 8-bit

0 commit comments

Comments
 (0)