Skip to content

Commit 5beaa57

Browse files
committed
[Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend
- Add QNN layer norm support for aten.native_layer_norm.default
- Handle missing weight/bias by creating an identity weight and a zero bias
- Always provide a bias tensor for the QNN LayerNorm op
- Add floating-point and quantized tests for native_layer_norm
- Print the generated .pte filename after export
1 parent 40404fd commit 5beaa57

5 files changed

Lines changed: 96 additions & 27 deletions

File tree

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
import numpy as np
1313
import torch
14-
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
14+
from executorch.backends.qualcomm.utils.constants import (
15+
QCOM_DATA,
16+
QCOM_QUANT_ATTRS,
17+
QCOM_ZERO_POINT,
18+
)
19+
from executorch.exir.dialects._ops import ops as exir_ops
1520

1621
from .node_visitor import NodeVisitor
1722
from .node_visitor_manager import register_node_visitor
@@ -31,6 +36,7 @@ def define_node(
3136
node: torch.fx.Node,
3237
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
3338
) -> PyQnnManager.PyQnnOpWrapper:
39+
# args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
3440
input_node = self.get_node(node.args[0])
3541
input_tensor = self.get_tensor(input_node, node)
3642
input_tensor_wrapper = self.define_tensor(
@@ -54,37 +60,61 @@ def define_node(
5460
axis = [len(input_tensor.shape) - 1]
5561
axis_shape = [len(axis)]
5662

57-
weight_node = self.get_node(node.args[2])
58-
if weight_node is not None:
63+
has_weight = len(node.args) > 2 and node.args[2] is not None
64+
if has_weight:
65+
weight_node = self.get_node(node.args[2])
5966
weight_tensor = get_parameter(weight_node, self.edge_program)
60-
weight_tensor_wrapper = self.define_tensor(
61-
weight_node,
62-
node,
63-
weight_tensor,
64-
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
65-
nodes_to_wrappers,
66-
)
67-
layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
6867
else:
69-
warnings.warn(
70-
"[QNN Delegate Op Builder]: LayerNorm weight is None, skipping",
71-
stacklevel=1,
68+
# elementwise_affine=False: use all-ones weight as identity
69+
weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
70+
weight_node = torch.fx.Node(
71+
node.graph,
72+
node.name + "_runtime_weight",
73+
"call_function",
74+
exir_ops.edge.aten.tensor.default,
75+
(),
76+
{},
7277
)
73-
layer_norm_input_tensors = [input_tensor_wrapper]
78+
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
79+
quant_attrs = quant_attrs.copy()
80+
quant_attrs[QCOM_ZERO_POINT] = 0
81+
weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
82+
weight_tensor_wrapper = self.define_tensor(
83+
weight_node,
84+
node,
85+
weight_tensor,
86+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
87+
nodes_to_wrappers,
88+
)
7489

75-
bias_node = self.get_node(node.args[3])
76-
if bias_node is not None:
90+
# Fake node: even when original bias is absent, QNN still needs it
91+
has_bias = len(node.args) > 3 and node.args[3] is not None
92+
if has_bias:
93+
bias_node = self.get_node(node.args[3])
7794
bias_tensor = get_parameter(bias_node, self.edge_program)
78-
bias_tensor_wrapper = self.define_tensor(
79-
bias_node,
80-
node,
81-
bias_tensor,
82-
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
83-
nodes_to_wrappers,
95+
else:
96+
bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32)
97+
bias_node = torch.fx.Node(
98+
node.graph,
99+
node.name + "_runtime_bias",
100+
"call_function",
101+
exir_ops.edge.aten.tensor.default,
102+
(),
103+
{},
84104
)
85-
layer_norm_input_tensors.append(bias_tensor_wrapper)
105+
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
106+
quant_attrs = quant_attrs.copy()
107+
quant_attrs[QCOM_ZERO_POINT] = 0
108+
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
109+
bias_tensor_wrapper = self.define_tensor(
110+
bias_node,
111+
node,
112+
bias_tensor,
113+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
114+
nodes_to_wrappers,
115+
)
86116

87-
epsilon = node.args[4]
117+
epsilon = node.args[4] if len(node.args) > 4 else 1e-05
88118

89119
output_tensor = self.get_tensor(node, node, 0)
90120
output_tensor_wrapper = self.define_tensor(
@@ -100,7 +130,9 @@ def define_node(
100130
QNN_OP_PACKAGE_NAME_QTI_AISW,
101131
OpLayerNorm.op_name,
102132
)
103-
layer_norm_op.AddInputTensors(layer_norm_input_tensors)
133+
layer_norm_op.AddInputTensors(
134+
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
135+
)
104136
layer_norm_op.AddOutputTensors([output_tensor_wrapper])
105137
layer_norm_op.AddScalarParam(
106138
OpLayerNorm.param_epsilon,

backends/qualcomm/export_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ def build_executorch_binary(
617617
with open(pte_name, "wb") as file:
618618
exec_prog_mgr.write_to_file(file)
619619

620+
print(f"Successfully generated {pte_name}.")
620621
if qnn_config.compile_only:
621622
sys.exit(0)
622623

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
828828

829829

830830
@register_annotator(
831-
[torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
831+
[torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
832+
QnnConstants.OpLayerNorm.op_name,
832833
)
833834
class LayerNorm(GeneralOpDef):
834835
@staticmethod

backends/qualcomm/tests/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,26 @@ def forward(self, x):
13881388
return self.linear(self.layer_norm(x))
13891389

13901390

1391+
class NativeLayerNorm(torch.nn.Module):
    """Test module that calls torch.native_layer_norm directly.

    With ``affine=True`` (default) the registered identity weight (ones) and
    zero bias are passed through. With ``affine=False`` the weight and bias
    arguments are passed as ``None`` so the backend's identity-weight /
    zero-bias fallback path is exercised.

    Bug fix: the original ``forward`` had byte-identical ``if``/``else``
    branches, so ``affine=False`` never actually dropped weight/bias and the
    affine-free path was untested.
    """

    def __init__(self, affine=True):
        super().__init__()
        self.affine = affine
        # Parameters are registered unconditionally so the module's state
        # dict is the same in both modes; they are simply not forwarded
        # when affine is False.
        self.weight = torch.nn.Parameter(torch.ones(768))
        self.bias = torch.nn.Parameter(torch.zeros(768))
        self.normalized_shape = [768]
        self.eps = 1e-6

    def forward(self, x):
        # native_layer_norm returns (output, mean, rstd); only the
        # normalized output is relevant for the delegate test.
        if self.affine:
            return torch.native_layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )[0]
        return torch.native_layer_norm(
            x, self.normalized_shape, None, None, self.eps
        )[0]
1409+
1410+
13911411
class LayerNormAdd(torch.nn.Module):
13921412
def __init__(self):
13931413
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,13 @@ def test_qnn_backend_layer_norm(self):
13841384
with self.subTest(i=i):
13851385
self.lower_module_and_test_output(module, sample_input)
13861386

1387+
def test_qnn_backend_native_layer_norm(self):
    """Lower and run both affine and affine-free native_layer_norm variants."""
    sample_input = (torch.randn(196, 768),)
    variants = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for idx, variant in enumerate(variants):
        with self.subTest(i=idx):
            self.lower_module_and_test_output(variant, sample_input)
1393+
13871394
def test_qnn_backend_leaky_relu(self):
13881395
torch.manual_seed(8)
13891396
test_comb = [
@@ -3811,6 +3818,14 @@ def test_qnn_backend_layer_norm(self):
38113818
module = self.get_qdq_module(module, sample_input)
38123819
self.lower_module_and_test_output(module, sample_input)
38133820

3821+
def test_qnn_backend_native_layer_norm(self):
    """Quantize (QDQ) then lower both native_layer_norm variants."""
    sample_input = (torch.randn(196, 768),)
    variants = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for idx, variant in enumerate(variants):
        with self.subTest(i=idx):
            qdq_module = self.get_qdq_module(variant, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)
3828+
38143829
def test_qnn_backend_leaky_relu(self):
38153830
test_comb = [
38163831
{

0 commit comments

Comments
 (0)