Skip to content

Commit bd5752a

Browse files
authored
Fix ChannelsLastTaggedReshapePass crash on ops with list-typed args (#18958)
The `input_to_nhwc` method crashes with `AttributeError: 'immutable_list' object has no attribute 'graph'` when lowering dynamically quantized models containing ops like `cat` whose first argument is a list of tensors. Two fixes: 1. Add `isinstance(args[0], torch.fx.Node)` guard in the dynamic input trace-back loop to stop when args[0] is not a Node. 2. Handle list-typed `node.args[0]` at the call site by iterating over each element and converting them individually to NHWC. Fixes #18944
1 parent 9e36d62 commit bd5752a

2 files changed

Lines changed: 44 additions & 3 deletions

File tree

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,11 @@ def input_to_nhwc(
398398
is_dynamic_input = is_dynamic_qdq(input_node)
399399

400400
if is_dynamic_input:
401-
# Trace back to original source node
402-
while getattr(input_node, "args", None):
401+
# Trace back to original source node. Stop if args[0] is not
402+
# a Node (e.g., immutable_list from cat).
403+
while getattr(input_node, "args", None) and isinstance(
404+
input_node.args[0], torch.fx.Node
405+
):
403406
input_node = input_node.args[0]
404407

405408
with graph_module.graph.inserting_after(input_node):
@@ -505,7 +508,13 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
505508
elif self.requires_nhwc_input(node):
506509
# Nodes which enter this branch are ones that require their
507510
# first input to be nhwc. This makes this node's output nhwc too
508-
self.input_to_nhwc(graph_module, node.args[0], node)
511+
if isinstance(node.args[0], (list, tuple)):
512+
# Ops like cat have a list of tensors as args[0].
513+
for arg in node.args[0]:
514+
if isinstance(arg, torch.fx.Node):
515+
self.input_to_nhwc(graph_module, arg, node)
516+
else:
517+
self.input_to_nhwc(graph_module, node.args[0], node)
509518
for input_node in node.all_input_nodes[1:]:
510519
if (
511520
input_node.op == "placeholder"

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,3 +633,35 @@ def forward(self, x):
633633
]
634634
self.assertEqual(1, len(view_nodes))
635635
self.assertTrue(ChannelsLastTaggedReshapePass(None).is_nchw_node(view_nodes[0]))
636+
637+
class ConvCat(torch.nn.Module):
638+
def __init__(self):
639+
super().__init__()
640+
self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
641+
self.conv2 = torch.nn.Conv2d(3, 16, 3, padding=1)
642+
643+
def forward(self, x):
644+
return torch.cat([self.conv1(x), self.conv2(x)], dim=1)
645+
646+
def test_fp32_conv_cat_immutable_list(self):
647+
model = self.ConvCat().eval()
648+
x = torch.randn(1, 3, 8, 8)
649+
self.run_tester(model, (x,))
650+
651+
def test_dq_conv_cat_immutable_list(self):
652+
model = self.ConvCat().eval()
653+
x = torch.randn(1, 3, 8, 8)
654+
(
655+
Tester(model, (x,))
656+
.quantize(
657+
Quantize(
658+
quantization_config=get_symmetric_quantization_config(
659+
is_dynamic=True
660+
)
661+
)
662+
)
663+
.export()
664+
.to_edge()
665+
.run_passes(self.PassStage)
666+
.run_method_and_compare_outputs()
667+
)

0 commit comments

Comments
 (0)