diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index df80749e72f..85e3476cad3 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -174,12 +174,21 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Check if the output is also quantized (q → dq → linear → q pattern) # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + # Due to decomposition of aten.linear for 3D+ inputs, there may be a + # view_copy between the mm output and the quantize node. self.quantize_output_node = None self.output_scales_node = None self.output_zeros_node = None self.relu_node = None + self.output_view_copy_node = None if len(self.output_node.users) == 1: cur_node = list(self.output_node.users)[0] + # Skip potential view_copy between linear and output quantize + if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1: + self.output_view_copy_node = cur_node + self.all_nodes.append(self.output_view_copy_node) + self.output_node = self.output_view_copy_node + cur_node = list(cur_node.users)[0] if cur_node.target == exir_ops.edge.aten.relu.default: self.relu_node = cur_node if len(cur_node.users) == 1: