Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_DTYPE,
QCOM_ENCODING,
Expand Down Expand Up @@ -130,7 +130,7 @@ def _annotate_quant_attrs(
self._annotate_requant(n)
# With fold_quant enabled, check if the input of dq op is quantized param.
param = None
if n.target in dq_ops:
if n.target in dq_ops and is_parameter(n.args[0], self.edge_program):
param = get_parameter(n.args[0], self.edge_program)
if n.target not in q_ops and param is None:
continue
Expand Down
38 changes: 32 additions & 6 deletions backends/qualcomm/builders/op_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
from executorch.backends.qualcomm.utils.constants import (
QCOM_DATA,
QCOM_QUANT_ATTRS,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
Expand All @@ -31,6 +36,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
# args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
Expand All @@ -54,8 +60,26 @@ def define_node(
axis = [len(input_tensor.shape) - 1]
axis_shape = [len(axis)]

weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
has_weight = len(node.args) > 2 and node.args[2] is not None
if has_weight:
weight_node = self.get_node(node.args[2])
assert weight_node is not None
weight_tensor = get_parameter(weight_node, self.edge_program)
else:
# elementwise_affine=False: use all-ones weight as identity
weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
weight_node = torch.fx.Node(
node.graph,
node.name + "_runtime_weight",
"call_function",
exir_ops.edge.aten.tensor.default,
(),
{},
)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
Expand All @@ -66,8 +90,10 @@ def define_node(

layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]

bias_node = self.get_node(node.args[3])
if bias_node is not None:
has_bias = len(node.args) > 3 and node.args[3] is not None
if has_bias:
bias_node = self.get_node(node.args[3])
assert bias_node is not None
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
Expand All @@ -78,7 +104,7 @@ def define_node(
)
layer_norm_input_tensors.append(bias_tensor_wrapper)

epsilon = node.args[4]
epsilon = node.args[4] if len(node.args) > 4 else 1e-05

output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
Expand Down
10 changes: 6 additions & 4 deletions backends/qualcomm/builders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ def get_parameter(
param = get_buffer(edge_program, node)
if is_lifted_tensor_constant(edge_program, node):
param = get_lifted_tensor_constant(edge_program, node)
if param is not None:
# update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
param = param.type(node.meta["val"].dtype)
assert param is not None, (
f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
)
# update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
param = param.type(node.meta["val"].dtype)
Comment on lines +40 to +45
return param


Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def build_executorch_binary(
with open(pte_name, "wb") as file:
exec_prog_mgr.write_to_file(file)

print(f"Successfully generated {pte_name}.")
if qnn_config.compile_only:
sys.exit(0)

Expand Down
39 changes: 20 additions & 19 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,16 +828,15 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:


@register_annotator(
[torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
[torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
QnnConstants.OpLayerNorm.op_name,
)
class LayerNorm(GeneralOpDef):
@staticmethod
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
act_node = node.args[0]
weight_node = node.args[2]
bias_node = None
if len(node.args) > 2:
bias_node = node.args[3]
weight_node = node.args[2] if len(node.args) > 2 else None
bias_node = node.args[3] if len(node.args) > 3 else None

if _is_annotated([node]):
return
Expand All @@ -848,20 +847,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
act_node,
input_act_qspec,
)
if input_act_qspec.dtype == torch.int32:
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
nodes_to_mark_annotated = [node]
if isinstance(weight_node, Node):
if input_act_qspec.dtype == torch.int32:
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated.append(weight_node)
if isinstance(bias_node, Node):
annotate_input_qspec_map(
node,
bias_node,
Expand Down
36 changes: 18 additions & 18 deletions backends/qualcomm/quantizer/annotators/lpai_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,8 @@ class LayerNorm(GeneralOpDef):
@staticmethod
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
act_node = node.args[0]
weight_node = node.args[2]
bias_node = None
if len(node.args) > 2:
bias_node = node.args[3]
weight_node = node.args[2] if len(node.args) > 2 else None
bias_node = node.args[3] if len(node.args) > 3 else None
Comment on lines 475 to +479
Comment on lines 475 to +479

if _is_annotated([node]):
return
Expand All @@ -489,20 +487,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
act_node,
input_act_qspec,
)
if input_act_qspec.dtype == torch.int32:
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
nodes_to_mark_annotated = [node]
if isinstance(weight_node, Node):
if input_act_qspec.dtype == torch.int32:
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated.append(weight_node)
if isinstance(bias_node, Node):
annotate_input_qspec_map(
node,
bias_node,
Expand Down
9 changes: 7 additions & 2 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,9 +1379,14 @@ def forward(self, x):


class LayerNorm(torch.nn.Module):
def __init__(self, bias=True):
def __init__(self, elementwise_affine=True, bias=True):
super().__init__()
self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
self.layer_norm = torch.nn.LayerNorm(
[768],
eps=1e-6,
elementwise_affine=elementwise_affine,
bias=bias,
)
self.linear = torch.nn.Linear(768, 196)

def forward(self, x):
Expand Down
12 changes: 10 additions & 2 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
self.lower_module_and_test_output(module, sample_input)
Comment on lines 1381 to 1385

def test_qnn_backend_layer_norm(self):
modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405
modules = [
LayerNorm(), # noqa: F405
LayerNorm(bias=False), # noqa: F405
LayerNorm(elementwise_affine=False), # noqa: F405
]
sample_input = (torch.randn(196, 768),)
for i, module in enumerate(modules):
with self.subTest(i=i):
Expand Down Expand Up @@ -3804,7 +3808,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_layer_norm(self):
modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405
modules = [
LayerNorm(), # noqa: F405
LayerNorm(bias=False), # noqa: F405
LayerNorm(elementwise_affine=False), # noqa: F405
]
sample_input = (torch.randn(196, 768),)
for i, module in enumerate(modules):
with self.subTest(i=i):
Expand Down
Loading