
Commit a49171d

SS-JIA authored and committed
[ET-VK][q8ta_pixel_shuffle] Add fused PixelShuffle custom op for channels-packed int8 tensors
Pull Request resolved: #19397

A RefineNet segmentation model spends ~860 us (~17% of inference) on the textbook decomposed PyTorch PixelShuffle chain (q8ta_dequantize -> view -> permute -> view -> q8ta_quantize), repeated four times in the segmentation head. This is wasteful: it materializes three intermediate buffers and round-trips through fp32 just to perform what is fundamentally a byte permutation on an int8 tensor.

This diff introduces et_vk.q8ta_pixel_shuffle.default, a single fused kernel that operates directly on int8x4 packed buffers. Each thread writes one output int32 word (= 4 consecutive output channels at one (n, oh, ow) spatial position). Dispatch is 1D over the total number of output int words, sized as N * div_up_4(C_out) * H_out * W_out with a 64-thread local workgroup.

The four channel lanes inside an output int come from four different input int words (input channels are spaced by r*r), so each thread issues four input loads. The (oh % r, ow % r) -> input lane mapping is constant for a given thread because all four output lanes share (oh, ow). The first byte index is computed via the layout-aware helper tensor4d_idx_to_buf_idx; subsequent lanes derive their byte index by adding stride[packed_dim] * block_numel, a layout-only constant, so only one helper call is needed per thread.

When input and output share scale and zero-point (the typical residual-path case), the requantize math is skipped and the kernel becomes a pure byte shuffle (selected via the passthrough push constant).

The op accepts the channels-packed PACKED_INT8 family (PACKED_INT8_4W4C, PACKED_INT8_4C1W, PACKED_INT8_CONV2D) on both input and output. The partitioner routes the op into whichever channels-packed layout the surrounding q8ta_conv2d_pw / q8ta_add ops produce/consume (PACKED_INT8_4W4C on RefineNet).
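As a sanity check on the index math above, here is a pure-Python sketch (not the GLSL kernel; `fused_pixel_shuffle` and `decomposed_pixel_shuffle` are illustrative names) showing that the output-side fused mapping agrees with the textbook view -> permute -> view chain:

```python
# Pure-Python sketch of the fused output-side index mapping vs. the textbook
# decomposed chain. Tensors are nested lists indexed [c][h][w]; all names are
# illustrative, not taken from the diff.

def fused_pixel_shuffle(inp, r):
    # Output-side mapping used by the kernel: for output (oc, oh, ow), read
    # input channel oc*r*r + (oh % r)*r + (ow % r) at (oh // r, ow // r).
    # The four lanes oc..oc+3 of one int8x4 word share (oh, ow), so their
    # input channels differ by a constant stride of r*r.
    c, h, w = len(inp), len(inp[0]), len(inp[0][0])
    return [
        [
            [inp[oc * r * r + (oh % r) * r + (ow % r)][oh // r][ow // r]
             for ow in range(w * r)]
            for oh in range(h * r)
        ]
        for oc in range(c // (r * r))
    ]

def decomposed_pixel_shuffle(inp, r):
    # Input-side mapping of the view -> permute -> view chain: input channel
    # oc*r*r + i*r + j scatters to output (oc, ih*r + i, iw*r + j).
    c, h, w = len(inp), len(inp[0]), len(inp[0][0])
    out = [[[None] * (w * r) for _ in range(h * r)]
           for _ in range(c // (r * r))]
    for oc in range(c // (r * r)):
        for i in range(r):
            for j in range(r):
                for ih in range(h):
                    for iw in range(w):
                        out[oc][ih * r + i][iw * r + j] = \
                            inp[oc * r * r + i * r + j][ih][iw]
    return out
```

The two mappings are algebraically identical (substitute oh = ih*r + i, ow = iw*r + j), which is why the fused kernel can compute all four input byte indices from one layout-helper call plus a constant channel stride.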
Restricting to the channels-packed family means the inner block axis is always C and the lane within an int word is constant per thread, which removes the need for layout-block-config spec consts in the shader.

Rather than matching the decomposed view -> permute -> view chain after to_edge lowering, this diff preserves aten.pixel_shuffle.default through to_edge by adding it to the partitioner's ops_to_not_decompose list. The matcher then operates on the much simpler dq -> [clone] -> aten.pixel_shuffle.default -> [clone] -> q form. This keeps the matcher robust against edge-dialect / clone-insertion variations.

Pieces in this diff:

- Partitioner / fuser:
  - partitioner/vulkan_partitioner.py — adds aten.pixel_shuffle.default to ops_to_not_decompose so the framework preserves the op through to_edge lowering.
  - patterns/quantized_pixel_shuffle.py — detects dq -> [clone] -> aten.pixel_shuffle.default -> [clone] -> q and rewrites it to et_vk.q8ta_pixel_shuffle.default. Transparently skips clone / _clone_dim_order nodes between any pair of nodes.
- Runtime kernel:
  - runtime/graph/ops/glsl/q8ta_pixel_shuffle.glsl + .yaml
  - runtime/graph/ops/impl/Q8taPixelShuffle.cpp + .h
- Op definitions:
  - custom_ops_lib.py: register et_vk.q8ta_pixel_shuffle (Python op definition).
  - op_registry.py: inputs_storage = utils.PACKED_INT8_CHANNELS_PACKED_BUFFER.
- Tests:
  - test/custom_ops/impl/TestQ8taPixelShuffle.cpp: test op that runs q -> [fused | unfused chain] -> dq, with selectable input/output int8 layouts via string args. The op accepts the channels-packed family; the layout_from_string helper currently exercises 4W4C.
  - test/custom_ops/test_q8ta_pixel_shuffle.cpp: 16 ACCU + 8 PERF cases (4 shapes x 2 qparam settings x 2 impl_selectors x 1 layout combination, 4W4C -> 4W4C).
  - test/test_vulkan_passes.py: positive and negative pattern-matcher unit tests against the un-decomposed form.
ghstack-source-id: 379519848
@exported-using-ghexport
Differential Revision: [D104099055](https://our.internmc.facebook.com/intern/diff/D104099055/)
1 parent 8e18287 commit a49171d

13 files changed

Lines changed: 1184 additions & 0 deletions

backends/vulkan/custom_ops_lib.py

Lines changed: 33 additions & 0 deletions
@@ -954,6 +954,39 @@ def q8ta_relu_impl(
 lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
 q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)
 
+###########################
+## q8ta_pixel_shuffle ##
+###########################
+
+
+def q8ta_pixel_shuffle_impl(
+    input: torch.Tensor,
+    input_scale: float,
+    input_zero_point: int,
+    output_inv_scale: float,
+    output_zero_point: int,
+):
+    # Reference Python impl for op registration. The runtime kernel does a
+    # fused byte-shuffle (and optional requantize when scales differ).
+    output_scale = 1.0 / output_inv_scale
+    dequant = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input, input_scale, input_zero_point, -128, 127, input.dtype
+    )
+    shuffled = torch.nn.functional.pixel_shuffle(dequant, upscale_factor)
+    requantized = torch.ops.quantized_decomposed.quantize_per_tensor(
+        shuffled, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+    return requantized
+
+
+name = "q8ta_pixel_shuffle"
+lib.define(
+    f"{name}(Tensor input, float input_scale, int input_zero_point, float output_inv_scale, int output_zero_point, int upscale_factor) -> Tensor"
+)
+lib.impl(name, q8ta_pixel_shuffle_impl, "CompositeExplicitAutograd")
+q8ta_pixel_shuffle_op = getattr(getattr(torch.ops, namespace), name)
+
 ########################
 ## embedding_q4gsw ##
 ########################
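The reference impl above always round-trips through fp32. A quick sketch (plain Python, illustrative names, not the runtime's code) of why the kernel can skip that work in the shared-qparams case: dequantize followed by requantize with identical scale and zero-point is the identity on every int8 value, so the fused op degenerates to a pure byte shuffle.

```python
# Sketch: with identical input/output qparams, the dq -> q round trip is the
# identity on int8 values, so pixel_shuffle commutes with it and the kernel's
# passthrough path (no requantize math) is exact. q8_roundtrip is an
# illustrative name, not from the diff.

def q8_roundtrip(v, scale, zero_point):
    # dequantize one int8 value, then requantize with the same qparams
    fp = (v - zero_point) * scale
    q = round(fp / scale) + zero_point
    return max(-128, min(127, q))  # clamp to the int8 range
```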

backends/vulkan/op_registry.py

Lines changed: 19 additions & 0 deletions
@@ -611,6 +611,25 @@ def register_q8ta_relu():
 )
 
 
+# =============================================================================
+# Q8taPixelShuffle.cpp
+# =============================================================================
+
+
+@update_features(exir_ops.edge.et_vk.q8ta_pixel_shuffle.default)
+def register_q8ta_pixel_shuffle():
+    # The fused kernel is restricted to the channels-packed family
+    # (PACKED_INT8_4W4C, PACKED_INT8_4C1W, PACKED_INT8_CONV2D), all of which
+    # share packed_dim=C. See add_q8ta_pixel_shuffle_node in Q8taPixelShuffle.cpp
+    # for the runtime assertion. The surrounding q8ta_conv2d ops produce
+    # PACKED_INT8_4W4C on this model, so the partitioner can route through this
+    # op without inserting layout-transition q8ta_clone dispatches.
+    return OpFeatures(
+        inputs_storage=utils.PACKED_INT8_CHANNELS_PACKED_BUFFER,
+        supports_resize=True,
+    )
+
+
 # =============================================================================
 # =============================================================================
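For intuition about the int8x4 words that the channels-packed layouts share, here is a sketch (illustrative helpers, assuming little-endian lane order; not ExecuTorch APIs) of packing four consecutive channels into one 32-bit word, which is what lets each shader thread cover four output channels with a single store:

```python
# Sketch of int8x4 packing: channel lane k occupies byte k of a 32-bit word.
# pack_int8x4 / unpack_int8x4 are illustrative names, and little-endian lane
# order is an assumption made for this example.

def pack_int8x4(lanes):
    # lanes: four int8 values for channels c, c+1, c+2, c+3
    word = 0
    for k, v in enumerate(lanes):
        word |= (v & 0xFF) << (8 * k)
    return word  # unsigned 32-bit value

def unpack_int8x4(word):
    lanes = []
    for k in range(4):
        b = (word >> (8 * k)) & 0xFF
        lanes.append(b - 256 if b >= 128 else b)  # sign-extend the byte
    return lanes
```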

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ fbcode_target(_kind = runtime.python_library,
        "quantized_linear.py",
        "quantized_convolution.py",
        "quantized_binary.py",
+       "quantized_pixel_shuffle.py",
        "quantized_unary.py",
        "rms_norm.py",
        "sdpa.py",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 import executorch.backends.vulkan.patterns.quantized_linear  # noqa
 
+import executorch.backends.vulkan.patterns.quantized_pixel_shuffle  # noqa
+
 import executorch.backends.vulkan.patterns.quantized_unary  # noqa
 
 import executorch.backends.vulkan.patterns.rms_norm  # noqa
backends/vulkan/patterns/quantized_pixel_shuffle.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Set
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch.fx.node import Argument
+
+
+# Set of ops that act as no-ops on values (i.e. clones / dim_order copies that
+# preserve dtype and shape). The matcher transparently skips these between the
+# dequantize, pixel_shuffle, and quantize nodes.
+_NOOP_PASSTHROUGH_TARGETS: Set[object] = {
+    exir_ops.edge.aten.clone.default,
+    exir_ops.edge.dim_order_ops._clone_dim_order.default,
+}
+
+
+def _is_noop_passthrough(node: torch.fx.Node) -> bool:
+    return node.op == "call_function" and node.target in _NOOP_PASSTHROUGH_TARGETS
+
+
+def _skip_passthrough_user(
+    node: torch.fx.Node, collected: List[torch.fx.Node]
+) -> Optional[torch.fx.Node]:
+    """Given `node`, advance to its next non-passthrough user, walking through
+    any chain of clone/dim_order_copy ops in between (collecting them in
+    `collected`). Returns None if `node` does not have exactly one user, or if
+    any intermediate passthrough has more than one user."""
+    if len(node.users) != 1:
+        return None
+    cur = next(iter(node.users))
+    while _is_noop_passthrough(cur):
+        collected.append(cur)
+        if len(cur.users) != 1:
+            return None
+        cur = next(iter(cur.users))
+    return cur
+
+
+class QuantizedPixelShuffleMatch(PatternMatch):
+    """
+    Matches an un-decomposed PixelShuffle wrapped between a quant/dequant pair:
+
+        q8ta_dequantize_per_tensor (int8 -> fp32)
+        [optional] clone / _clone_dim_order
+        aten.pixel_shuffle.default (upscale_factor = r)
+        [optional] clone / _clone_dim_order
+        q8ta_quantize_per_tensor (fp32 -> int8)
+
+    The anchor is the dequantize node since it is a unique entry point.
+
+    This relies on the partitioner's `ops_to_not_decompose()` hook preserving
+    `aten.pixel_shuffle.default` through edge lowering, so we do not need to
+    re-detect the decomposed view -> permute -> view pattern.
+    """
+
+    def __init__(self, dequantize_node: torch.fx.Node) -> None:
+        self.anchor_node: torch.fx.Node = dequantize_node
+        self.match_found: bool = False
+        self.all_nodes: List[torch.fx.Node] = [dequantize_node]
+
+        # Validate the dequantize node is one of the quant decomposed ops.
+        if not utils.is_dequant_node(dequantize_node):
+            return
+
+        # Walk forward to the pixel_shuffle node (skipping any clones).
+        pixel_shuffle_node = _skip_passthrough_user(dequantize_node, self.all_nodes)
+        if pixel_shuffle_node is None:
+            return
+        if pixel_shuffle_node.op != "call_function":
+            return
+        if pixel_shuffle_node.target != exir_ops.edge.aten.pixel_shuffle.default:
+            return
+
+        # Walk forward to the quantize node (skipping any clones).
+        quantize_node = _skip_passthrough_user(pixel_shuffle_node, self.all_nodes)
+        if quantize_node is None or not utils.is_quant_node(quantize_node):
+            return
+
+        # pixel_shuffle args are (input, upscale_factor).
+        if len(pixel_shuffle_node.args) < 2:
+            return
+        upscale_factor = pixel_shuffle_node.args[1]
+        if not isinstance(upscale_factor, int):
+            return
+
+        # Capture the nodes and quant params we need for the replacement.
+        self.dequantize_input_node = dequantize_node
+        self.pixel_shuffle_node: torch.fx.Node = pixel_shuffle_node
+        self.quantize_output_node: torch.fx.Node = quantize_node
+
+        self.input_int8_node: Argument = dequantize_node.args[0]
+        self.input_scales_node: Argument = dequantize_node.args[1]
+        self.input_zeros_node: Argument = dequantize_node.args[2]
+        self.output_scales_node: Argument = quantize_node.args[1]
+        self.output_zeros_node: Argument = quantize_node.args[2]
+        self.upscale_factor: int = upscale_factor
+
+        self.all_nodes.extend([pixel_shuffle_node, quantize_node])
+        # The replacement target replaces uses of the quantize node.
+        self.output_node: torch.fx.Node = quantize_node
+
+        self.match_found = True
+
+
+@register_pattern_detector("quantized_pixel_shuffle")
+def find_quantized_pixel_shuffle_pattern(
+    node: torch.fx.Node,
+) -> Optional[QuantizedPixelShuffleMatch]:
+    if node.op != "call_function":
+        return None
+    if not utils.is_dequant_node(node):
+        return None
+    matched = QuantizedPixelShuffleMatch(node)
+    if matched.match_found:
+        return matched
+    return None
+
+
+@register_pattern_replacement("quantized_pixel_shuffle")
+def make_quantized_pixel_shuffle_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedPixelShuffleMatch,
+) -> None:
+    op_target = exir_ops.edge.et_vk.q8ta_pixel_shuffle.default
+
+    # The fused op takes the *inverse* of the output scale to match the
+    # runtime kernel's expectation.
+    output_scale = match.output_scales_node
+    inv_output_scale: object
+    if isinstance(output_scale, (int, float)):
+        inv_output_scale = 1.0 / float(output_scale)
+    else:
+        # Intentional bail-out at the replacement step (not a TODO). The
+        # matcher deliberately does not pre-validate that the output scale is
+        # a scalar: every observed quantize_per_tensor in real models has a
+        # baked-in float scale. If the scale is instead a graph node (rare
+        # for static per-tensor quant, but possible), we raise here so the
+        # failure is loud at fusion time rather than a silent miscompile.
+        raise NotImplementedError(
+            "quantized_pixel_shuffle pattern only supports scalar output scales"
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            op_target,
+            args=(
+                match.input_int8_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                inv_output_scale,
+                match.output_zeros_node,
+                match.upscale_factor,
+            ),
+        )
+
+    new_node.meta["val"] = match.output_node.meta["val"]
+    match.quantize_output_node.replace_all_uses_with(new_node)
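The clone-skipping walk in `_skip_passthrough_user` can be modeled without torch.fx. This is a toy sketch of the same traversal rule (`StubNode`, `walk_past_clones`, and the string targets are all illustrative): follow the unique-user chain through clone nodes, bailing out on any fan-out.

```python
# Toy model of the matcher's forward walk. Each stub node records its target
# and users; the walk mirrors _skip_passthrough_user: advance through
# single-user clone nodes, collecting them, and return None on any fan-out.

class StubNode:
    def __init__(self, target):
        self.target = target
        self.users = []

CLONE_TARGETS = {"clone", "_clone_dim_order"}

def walk_past_clones(node, collected):
    if len(node.users) != 1:
        return None
    cur = node.users[0]
    while cur.target in CLONE_TARGETS:
        collected.append(cur)
        if len(cur.users) != 1:
            return None
        cur = cur.users[0]
    return cur

# Build the dq -> clone -> pixel_shuffle -> q chain the matcher expects.
dq, cl, ps, q = (StubNode(t) for t in
                 ("dequantize", "clone", "pixel_shuffle", "quantize"))
dq.users.append(cl)
cl.users.append(ps)
ps.users.append(q)
```

Walking from `dq` skips the clone and lands on the pixel_shuffle stub; walking from a node with zero or multiple users returns None, which is exactly the matcher's bail-out condition.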
