Skip to content

Commit 40404fd

Browse files
KevinUW114514claude
andcommitted
[QNN] Guard get_parameter against node=None in LayerNormVisitor
Fixes AttributeError when aten.native_layer_norm has optional weight=None. Both weight and bias are guarded to handle the None case gracefully. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ee9a981 commit 40404fd

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,22 @@ def define_node(
5555
axis_shape = [len(axis)]
5656

5757
weight_node = self.get_node(node.args[2])
58-
weight_tensor = get_parameter(weight_node, self.edge_program)
59-
weight_tensor_wrapper = self.define_tensor(
60-
weight_node,
61-
node,
62-
weight_tensor,
63-
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
64-
nodes_to_wrappers,
65-
)
66-
67-
layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
58+
if weight_node is not None:
59+
weight_tensor = get_parameter(weight_node, self.edge_program)
60+
weight_tensor_wrapper = self.define_tensor(
61+
weight_node,
62+
node,
63+
weight_tensor,
64+
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
65+
nodes_to_wrappers,
66+
)
67+
layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
68+
else:
69+
warnings.warn(
70+
"[QNN Delegate Op Builder]: LayerNorm weight is None, skipping",
71+
stacklevel=1,
72+
)
73+
layer_norm_input_tensors = [input_tensor_wrapper]
6874

6975
bias_node = self.get_node(node.args[3])
7076
if bias_node is not None:

backends/qualcomm/builders/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def is_parameter(
2929

3030
def get_parameter(
3131
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
32-
) -> torch.Tensor:
32+
) -> Optional[torch.Tensor]:
33+
if node is None:
34+
return None
3335
param = None
3436
if is_param(edge_program, node):
3537
param = get_param(edge_program, node)

0 commit comments

Comments
 (0)