Skip to content

Commit 89587c2

Browse files
authored
Qualcomm AI Engine Direct - Fixed Conv2d + PReLU fusion issue (#19014)
1 parent 0919746 commit 89587c2

3 files changed

Lines changed: 108 additions & 14 deletions

File tree

backends/qualcomm/builders/op_prelu.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
99

1010
import torch
11-
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
1211

1312
from .node_visitor import get_parameter, NodeVisitor
1413
from .node_visitor_manager import register_node_visitor
@@ -38,19 +37,8 @@ def define_node(
3837
)
3938

4039
coeff_node = self.get_node(node.args[1])
41-
coeff = get_parameter(coeff_node, self.edge_program)
42-
coeff_tensor = torch.zeros(input_node.meta["val"].shape, dtype=coeff.dtype)
43-
# per-channel activation
44-
coeff_node_shape = coeff_node.meta["val"].shape
45-
if len(coeff_node_shape) and coeff_node_shape[0] > 1:
46-
for i in range(input_node.meta["val"].shape[1]):
47-
coeff_tensor = coeff_tensor.index_fill(1, torch.tensor([i]), coeff[i])
48-
else:
49-
coeff_tensor.fill_(coeff[0] if coeff.dim() else coeff)
50-
51-
if axis_order := input_node.meta.get(QCOM_AXIS_ORDER, None):
52-
coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous()
53-
40+
coeff_tensor = get_parameter(coeff_node, self.edge_program)
41+
# The coeff_tensor will be broadcast to match the input shape by QNN
5442
coeff_tensor_wrapper = self.define_tensor(
5543
coeff_node,
5644
node,

backends/qualcomm/tests/models.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,16 @@ def forward(self, x):
674674
return torch.flip(x, self.dims)
675675

676676

677+
class Conv2dLeakyReLU(torch.nn.Module):
678+
def __init__(self, negative_slope=0.01):
679+
super().__init__()
680+
self.conv = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
681+
self.leaky_relu = torch.nn.LeakyReLU(negative_slope)
682+
683+
def forward(self, x):
684+
return self.leaky_relu(self.conv(x))
685+
686+
677687
class Conv2dMaxPool2d(torch.nn.Module):
678688
def __init__(self):
679689
super().__init__()
@@ -690,6 +700,16 @@ def forward(self, x):
690700
return self.pool(self.conv(x))
691701

692702

703+
class Conv2dReLU(torch.nn.Module):
704+
def __init__(self):
705+
super().__init__()
706+
self.conv = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
707+
self.relu = torch.nn.ReLU()
708+
709+
def forward(self, x):
710+
return self.relu(self.conv(x))
711+
712+
693713
class Conv2dSequential(torch.nn.Module):
694714
def __init__(self, bias=True, channel_last=False):
695715
super().__init__()
@@ -1480,6 +1500,16 @@ def forward(self, x):
14801500
return self.linear(x)
14811501

14821502

1503+
class LinearLeakyReLU(torch.nn.Module):
1504+
def __init__(self, negative_slope=0.01):
1505+
super().__init__()
1506+
self.linear = torch.nn.Linear(32, 32)
1507+
self.leaky_relu = torch.nn.LeakyReLU(negative_slope)
1508+
1509+
def forward(self, x):
1510+
return self.leaky_relu(self.linear(x))
1511+
1512+
14831513
class LinearNonConstantWeight(torch.nn.Module):
14841514
def __init__(self):
14851515
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4873,6 +4873,82 @@ def test_qnn_backend_conv2d_max_pool2d(self):
48734873
module = self.get_qdq_module(module, sample_input)
48744874
self.lower_module_and_test_output(module, sample_input)
48754875

4876+
def test_qnn_backend_activation_fusion(self):
4877+
if self.enable_x86_64:
4878+
self.skipTest(
4879+
"At the moment, testing is only being conducted on the device."
4880+
)
4881+
test_cases = [
4882+
{
4883+
"name": "conv2d_leaky_relu",
4884+
QCOM_MODULE: Conv2dLeakyReLU(), # noqa: F405
4885+
QCOM_SAMPLE_INPUTS: (torch.randn(1, 32, 6, 2),),
4886+
"unfused_check": lambda ops: any(
4887+
"prelu.opt" in op.lower() for op in ops
4888+
),
4889+
"unfused_msg": "Unexpected PReLU op in HTP ops (LeakyReLU lowered to PReLU)",
4890+
},
4891+
{
4892+
"name": "conv2d_relu",
4893+
QCOM_MODULE: Conv2dReLU(), # noqa: F405
4894+
QCOM_SAMPLE_INPUTS: (torch.randn(1, 3, 28, 28),),
4895+
"unfused_check": lambda ops: any(
4896+
op.lower() in ("q::relu", "q::relu.opt")
4897+
or (("relu" in op.lower()) and ("conv" not in op.lower()))
4898+
for op in ops
4899+
),
4900+
"unfused_msg": "Unexpected standalone ReLU op in HTP ops",
4901+
},
4902+
{
4903+
"name": "linear_leaky_relu",
4904+
QCOM_MODULE: LinearLeakyReLU(), # noqa: F405
4905+
QCOM_SAMPLE_INPUTS: (torch.randn(1, 6, 2, 32),),
4906+
"unfused_check": lambda ops: any(
4907+
"prelu.opt" in op.lower() for op in ops
4908+
),
4909+
"unfused_msg": "Unexpected PReLU op in HTP ops (LeakyReLU lowered to PReLU)",
4910+
},
4911+
]
4912+
for tc in test_cases:
4913+
with self.subTest(tc["name"]):
4914+
torch.manual_seed(8)
4915+
module = self.get_qdq_module(tc[QCOM_MODULE], tc[QCOM_SAMPLE_INPUTS])
4916+
backend_options = generate_htp_compiler_spec(use_fp16=False)
4917+
compiler_spec = generate_qnn_executorch_compiler_spec(
4918+
soc_model=self.chipset_table[TestQNN.soc_model],
4919+
backend_options=backend_options,
4920+
profile_level=3,
4921+
)
4922+
with tempfile.TemporaryDirectory() as tmp_dir:
4923+
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
4924+
module, tc[QCOM_SAMPLE_INPUTS], compiler_spec
4925+
).to_executorch()
4926+
pte_path = f"{tmp_dir}/model.pte"
4927+
with open(pte_path, "wb") as f:
4928+
edge_prog_mgr.write_to_file(f)
4929+
adb = self.get_adb_tool(pte_path)
4930+
binaries_trace = generate_optrace(
4931+
tmp_dir,
4932+
self.chipset_table[TestQNN.soc_model],
4933+
adb,
4934+
pte_path,
4935+
[tc[QCOM_SAMPLE_INPUTS]],
4936+
)
4937+
htp_ops = []
4938+
for _, (_, qhas) in binaries_trace.items():
4939+
with open(qhas, "r") as qhas_file:
4940+
qhas_data = json.load(qhas_file)
4941+
for row in qhas_data["data"]["htp_op_types"]["data"]:
4942+
htp_ops.append(row["op"])
4943+
has_conv = any("ConvLayer" in op for op in htp_ops)
4944+
self.assertTrue(
4945+
has_conv, f"Expected Conv op in HTP ops, got: {htp_ops}"
4946+
)
4947+
self.assertFalse(
4948+
tc["unfused_check"](htp_ops),
4949+
f"{tc['unfused_msg']}, got: {htp_ops}",
4950+
)
4951+
48764952
def test_qnn_backend_conv2d_slice_copy(self):
48774953
module = Conv2dSliceCopy() # noqa: F405
48784954
sample_input = (torch.randn([2, 1, 3, 3]),)

0 commit comments

Comments
 (0)