Skip to content

Commit 477f3fc

Browse files
committed
Add UT
1 parent 19bc9b4 commit 477f3fc

2 files changed

Lines changed: 49 additions & 0 deletions

File tree

backends/qualcomm/tests/models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,31 @@ def forward(self, x):
875875
return self.second(self.first(x))
876876

877877

878+
class ConvFull(torch.nn.Module):
879+
def __init__(self, fill, full_shape):
880+
super().__init__()
881+
self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
882+
self.fill = fill
883+
self.full_shape = full_shape
884+
885+
def forward(self, x):
886+
y = self.conv(x)
887+
c = torch.full(self.full_shape, self.fill, dtype=y.dtype)
888+
return torch.cat([y, c], dim=1)
889+
890+
891+
class ConvFullLike(torch.nn.Module):
892+
def __init__(self, fill):
893+
super().__init__()
894+
self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
895+
self.fill = fill
896+
897+
def forward(self, x):
898+
y = self.conv(x)
899+
c = torch.full_like(y, self.fill)
900+
return torch.cat([y, c], dim=1)
901+
902+
878903
class ConvTranspose1dSingle(torch.nn.Module):
879904
def __init__(self, bias=True, dilation=1):
880905
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,6 +2344,17 @@ def test_qnn_backend_einsum_outer_product_relu(self):
23442344
)
23452345
self.lower_module_and_test_output(module, sample_input)
23462346

2347+
def test_qnn_backend_full_layout_transformed(self):
2348+
full_shape = (1, 16, 4, 6)
2349+
module = ConvFull(0.5, full_shape) # noqa: F405
2350+
sample_input = (torch.randn(1, 8, 4, 6),)
2351+
self.lower_module_and_test_output(module, sample_input)
2352+
2353+
def test_qnn_backend_full_like_layout_transformed(self):
2354+
module = ConvFullLike(0.5) # noqa: F405
2355+
sample_input = (torch.randn(1, 8, 4, 6),)
2356+
self.lower_module_and_test_output(module, sample_input)
2357+
23472358
# TODO: Create a new UT class for passes specific checks
23482359
def test_qnn_backend_lift_add_tensor(self):
23492360
module = LiftAddTensor() # noqa: F405
@@ -5095,6 +5106,19 @@ def test_qnn_backend_einsum_outer_product_relu(self):
50955106
module = self.get_qdq_module(module, sample_input)
50965107
self.lower_module_and_test_output(module, sample_input)
50975108

5109+
def test_qnn_backend_full_layout_transformed(self):
5110+
full_shape = (1, 16, 4, 6)
5111+
module = ConvFull(0.5, full_shape) # noqa: F405
5112+
sample_input = (torch.randn(1, 8, 4, 6),)
5113+
module = self.get_qdq_module(module, sample_input)
5114+
self.lower_module_and_test_output(module, sample_input)
5115+
5116+
def test_qnn_backend_full_like_layout_transformed(self):
5117+
module = ConvFullLike(0.5) # noqa: F405
5118+
sample_input = (torch.randn(1, 8, 4, 6),)
5119+
module = self.get_qdq_module(module, sample_input)
5120+
self.lower_module_and_test_output(module, sample_input)
5121+
50985122
@unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
50995123
def test_qnn_backend_masked_softmax(self):
51005124
if self.enable_x86_64:

0 commit comments

Comments
 (0)