1111
1212import numpy as np
1313import torch
14- from executorch .backends .qualcomm .utils .constants import QCOM_DATA
14+ from executorch .backends .qualcomm .utils .constants import (
15+ QCOM_DATA ,
16+ QCOM_QUANT_ATTRS ,
17+ QCOM_ZERO_POINT ,
18+ )
19+ from executorch .exir .dialects ._ops import ops as exir_ops
1520
1621from .node_visitor import NodeVisitor
1722from .node_visitor_manager import register_node_visitor
@@ -31,6 +36,7 @@ def define_node(
3136 node : torch .fx .Node ,
3237 nodes_to_wrappers : Dict [torch .fx .Node , PyQnnManager .TensorWrapper ],
3338 ) -> PyQnnManager .PyQnnOpWrapper :
39+ # args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
3440 input_node = self .get_node (node .args [0 ])
3541 input_tensor = self .get_tensor (input_node , node )
3642 input_tensor_wrapper = self .define_tensor (
@@ -54,37 +60,61 @@ def define_node(
5460 axis = [len (input_tensor .shape ) - 1 ]
5561 axis_shape = [len (axis )]
5662
57- weight_node = self .get_node (node .args [2 ])
58- if weight_node is not None :
63+ has_weight = len (node .args ) > 2 and node .args [2 ] is not None
64+ if has_weight :
65+ weight_node = self .get_node (node .args [2 ])
5966 weight_tensor = get_parameter (weight_node , self .edge_program )
60- weight_tensor_wrapper = self .define_tensor (
61- weight_node ,
62- node ,
63- weight_tensor ,
64- PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_STATIC ,
65- nodes_to_wrappers ,
66- )
67- layer_norm_input_tensors = [input_tensor_wrapper , weight_tensor_wrapper ]
6867 else :
69- warnings .warn (
70- "[QNN Delegate Op Builder]: LayerNorm weight is None, skipping" ,
71- stacklevel = 1 ,
68+ # elementwise_affine=False: use all-ones weight as identity
69+ weight_tensor = torch .ones (normalized_shapes , dtype = torch .float32 )
70+ weight_node = torch .fx .Node (
71+ node .graph ,
72+ node .name + "_runtime_weight" ,
73+ "call_function" ,
74+ exir_ops .edge .aten .tensor .default ,
75+ (),
76+ {},
7277 )
73- layer_norm_input_tensors = [input_tensor_wrapper ]
78+ if quant_attrs := node .meta .get (QCOM_QUANT_ATTRS ):
79+ quant_attrs = quant_attrs .copy ()
80+ quant_attrs [QCOM_ZERO_POINT ] = 0
81+ weight_node .meta [QCOM_QUANT_ATTRS ] = quant_attrs
82+ weight_tensor_wrapper = self .define_tensor (
83+ weight_node ,
84+ node ,
85+ weight_tensor ,
86+ PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_STATIC ,
87+ nodes_to_wrappers ,
88+ )
7489
75- bias_node = self .get_node (node .args [3 ])
76- if bias_node is not None :
90+ # The QNN LayerNorm op is always given a bias input: use the real bias when present, otherwise a fake node carrying an all-zeros bias
91+ has_bias = len (node .args ) > 3 and node .args [3 ] is not None
92+ if has_bias :
93+ bias_node = self .get_node (node .args [3 ])
7794 bias_tensor = get_parameter (bias_node , self .edge_program )
78- bias_tensor_wrapper = self .define_tensor (
79- bias_node ,
80- node ,
81- bias_tensor ,
82- PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_STATIC ,
83- nodes_to_wrappers ,
95+ else :
96+ bias_tensor = torch .zeros (normalized_shapes , dtype = torch .float32 )
97+ bias_node = torch .fx .Node (
98+ node .graph ,
99+ node .name + "_runtime_bias" ,
100+ "call_function" ,
101+ exir_ops .edge .aten .tensor .default ,
102+ (),
103+ {},
84104 )
85- layer_norm_input_tensors .append (bias_tensor_wrapper )
105+ if quant_attrs := node .meta .get (QCOM_QUANT_ATTRS ):
106+ quant_attrs = quant_attrs .copy ()
107+ quant_attrs [QCOM_ZERO_POINT ] = 0
108+ bias_node .meta [QCOM_QUANT_ATTRS ] = quant_attrs
109+ bias_tensor_wrapper = self .define_tensor (
110+ bias_node ,
111+ node ,
112+ bias_tensor ,
113+ PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_STATIC ,
114+ nodes_to_wrappers ,
115+ )
86116
87- epsilon = node .args [4 ]
117+ epsilon = node .args [4 ] if len ( node . args ) > 4 else 1e-05
88118
89119 output_tensor = self .get_tensor (node , node , 0 )
90120 output_tensor_wrapper = self .define_tensor (
@@ -100,7 +130,9 @@ def define_node(
100130 QNN_OP_PACKAGE_NAME_QTI_AISW ,
101131 OpLayerNorm .op_name ,
102132 )
103- layer_norm_op .AddInputTensors (layer_norm_input_tensors )
133+ layer_norm_op .AddInputTensors (
134+ [input_tensor_wrapper , weight_tensor_wrapper , bias_tensor_wrapper ]
135+ )
104136 layer_norm_op .AddOutputTensors ([output_tensor_wrapper ])
105137 layer_norm_op .AddScalarParam (
106138 OpLayerNorm .param_epsilon ,
0 commit comments