Skip to content

Commit 2ff32a3

Browse files
authored
Merge branch 'main' into docathon/contributing-cpp-tests
2 parents b448787 + a49171d commit 2ff32a3

38 files changed

Lines changed: 1984 additions & 55 deletions

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ def get_arg_tensor_source_repset(
252252
"""
253253
arg_node = op_node.args[arg_i]
254254

255-
# For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does
255+
# For non-tensor arguments, return ANY_STORAGE_INCL_PACKED_INT8 so that the repset does
256256
# not appear to be empty.
257257
if not utils.is_tensor_arg_node(arg_node):
258-
return utils.ALL_STORAGES_REPSET
258+
return utils.ANY_STORAGE_INCL_PACKED_INT8
259259

260260
# Special case for cat - use the first tensor in the list as representative
261261
if isinstance(arg_node, list):

backends/vulkan/custom_ops_lib.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,39 @@ def q8ta_relu_impl(
954954
lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd")
955955
q8ta_relu_op = getattr(getattr(torch.ops, namespace), name)
956956

957+
###########################
958+
## q8ta_pixel_shuffle ##
959+
###########################
960+
961+
962+
def q8ta_pixel_shuffle_impl(
963+
input: torch.Tensor,
964+
input_scale: float,
965+
input_zero_point: int,
966+
output_inv_scale: float,
967+
output_zero_point: int,
968+
upscale_factor: int,
969+
):
970+
# Reference Python impl for op registration. The runtime kernel does a
971+
# fused byte-shuffle (and optional requantize when scales differ).
972+
output_scale = 1.0 / output_inv_scale
973+
dequant = torch.ops.quantized_decomposed.dequantize_per_tensor(
974+
input, input_scale, input_zero_point, -128, 127, input.dtype
975+
)
976+
shuffled = torch.nn.functional.pixel_shuffle(dequant, upscale_factor)
977+
requantized = torch.ops.quantized_decomposed.quantize_per_tensor(
978+
shuffled, output_scale, output_zero_point, -128, 127, torch.int8
979+
)
980+
return requantized
981+
982+
983+
name = "q8ta_pixel_shuffle"
984+
lib.define(
985+
f"{name}(Tensor input, float input_scale, int input_zero_point, float output_inv_scale, int output_zero_point, int upscale_factor) -> Tensor"
986+
)
987+
lib.impl(name, q8ta_pixel_shuffle_impl, "CompositeExplicitAutograd")
988+
q8ta_pixel_shuffle_op = getattr(getattr(torch.ops, namespace), name)
989+
957990
########################
958991
## embedding_q4gsw ##
959992
########################

backends/vulkan/op_registry.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,25 @@ def register_q8ta_relu():
611611
)
612612

613613

614+
# =============================================================================
615+
# Q8taPixelShuffle.cpp
616+
# =============================================================================
617+
618+
619+
@update_features(exir_ops.edge.et_vk.q8ta_pixel_shuffle.default)
620+
def register_q8ta_pixel_shuffle():
621+
# The fused kernel is restricted to the channels-packed family
622+
# (PACKED_INT8_4W4C, PACKED_INT8_4C1W, PACKED_INT8_CONV2D), all of which
623+
# share packed_dim=C. See add_q8ta_pixel_shuffle_node in Q8taPixelShuffle.cpp
624+
# for the runtime assertion. The surrounding q8ta_conv2d ops produce
625+
# PACKED_INT8_4W4C on this model, so the partitioner can route through this
626+
# op without inserting layout-transition q8ta_clone dispatches.
627+
return OpFeatures(
628+
inputs_storage=utils.PACKED_INT8_CHANNELS_PACKED_BUFFER,
629+
supports_resize=True,
630+
)
631+
632+
614633
# =============================================================================
615634
# =============================================================================
616635

@@ -1158,7 +1177,7 @@ def register_permute_copy():
11581177
@update_features(exir_ops.edge.aten.view_copy.default)
11591178
def register_view_copy():
11601179
return OpFeatures(
1161-
inputs_storage=utils.ANY_STORAGE,
1180+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
11621181
inputs_dtypes=utils.FP_INT_BOOL_T,
11631182
supports_resize=True,
11641183
supports_highdim=True,
@@ -1213,7 +1232,7 @@ def register_unsqueeze_copy():
12131232
@update_features(exir_ops.edge.aten.clone.default)
12141233
def register_clone():
12151234
return OpFeatures(
1216-
inputs_storage=utils.ANY_STORAGE,
1235+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12171236
inputs_dtypes=utils.FP_INT_BOOL_T,
12181237
supports_resize=True,
12191238
supports_highdim=True,
@@ -1223,7 +1242,7 @@ def register_clone():
12231242
@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
12241243
def register_clone_dim_order():
12251244
return OpFeatures(
1226-
inputs_storage=utils.ANY_STORAGE,
1245+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12271246
inputs_dtypes=utils.FP_INT_BOOL_T,
12281247
supports_resize=True,
12291248
supports_highdim=True,
@@ -1237,7 +1256,7 @@ def register_clone_dim_order():
12371256
@update_features(exir_ops.edge.aten.alias_copy.default)
12381257
def register_alias_copy():
12391258
return OpFeatures(
1240-
inputs_storage=utils.ANY_STORAGE,
1259+
inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8,
12411260
inputs_dtypes=utils.FP_INT_BOOL_T,
12421261
supports_resize=True,
12431262
supports_highdim=True,
@@ -1505,6 +1524,20 @@ def register_upsample_cpp_ops():
15051524
)
15061525

15071526

1527+
# =============================================================================
1528+
# PixelShuffle.cpp
1529+
# =============================================================================
1530+
1531+
1532+
@update_features(exir_ops.edge.aten.pixel_shuffle.default)
1533+
def register_pixel_shuffle():
1534+
return OpFeatures(
1535+
inputs_storage=utils.ANY_STORAGE,
1536+
inputs_dtypes=utils.FP_T,
1537+
supports_resize=True,
1538+
)
1539+
1540+
15081541
# =============================================================================
15091542
# GridPriors.cpp
15101543
# =============================================================================

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
ops_not_to_decompose = [
4949
torch.ops.aten.hardswish.default,
5050
torch.ops.aten.upsample_nearest2d.vec,
51+
torch.ops.aten.pixel_shuffle.default,
5152
]
5253

5354
logger: logging.Logger = logging.getLogger("")

backends/vulkan/patterns/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ fbcode_target(_kind = runtime.python_library,
1515
"quantized_linear.py",
1616
"quantized_convolution.py",
1717
"quantized_binary.py",
18+
"quantized_pixel_shuffle.py",
1819
"quantized_unary.py",
1920
"rms_norm.py",
2021
"sdpa.py",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import executorch.backends.vulkan.patterns.quantized_linear # noqa
1616

17+
import executorch.backends.vulkan.patterns.quantized_pixel_shuffle # noqa
18+
1719
import executorch.backends.vulkan.patterns.quantized_unary # noqa
1820

1921
import executorch.backends.vulkan.patterns.rms_norm # noqa
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List, Optional, Set
8+
9+
import executorch.backends.vulkan.utils as utils
10+
11+
import torch
12+
13+
from executorch.backends.vulkan.patterns.pattern_registry import (
14+
PatternMatch,
15+
register_pattern_detector,
16+
register_pattern_replacement,
17+
)
18+
19+
from executorch.exir import ExportedProgram
20+
from executorch.exir.dialects._ops import ops as exir_ops
21+
22+
from torch.fx.node import Argument
23+
24+
25+
# Set of ops that act as no-ops on values (i.e. clones / dim_order copies that
26+
# preserve dtype and shape). The matcher transparently skips these between the
27+
# dequantize, pixel_shuffle, and quantize nodes.
28+
_NOOP_PASSTHROUGH_TARGETS: Set[object] = {
29+
exir_ops.edge.aten.clone.default,
30+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
31+
}
32+
33+
34+
def _is_noop_passthrough(node: torch.fx.Node) -> bool:
35+
return node.op == "call_function" and node.target in _NOOP_PASSTHROUGH_TARGETS
36+
37+
38+
def _skip_passthrough_user(
39+
node: torch.fx.Node, collected: List[torch.fx.Node]
40+
) -> Optional[torch.fx.Node]:
41+
"""Given `node`, advance to its next non-passthrough user, walking through
42+
any chain of clone/dim_order_copy ops in between (collecting them in
43+
`collected`). Returns None if `node` does not have exactly one user, or if any
44+
intermediate passthrough has more than one user."""
45+
if len(node.users) != 1:
46+
return None
47+
cur = next(iter(node.users))
48+
while _is_noop_passthrough(cur):
49+
collected.append(cur)
50+
if len(cur.users) != 1:
51+
return None
52+
cur = next(iter(cur.users))
53+
return cur
54+
55+
56+
class QuantizedPixelShuffleMatch(PatternMatch):
57+
"""
58+
Matches an un-decomposed PixelShuffle wrapped between a quant/dequant pair:
59+
60+
q8ta_dequantize_per_tensor (int8 -> fp32)
61+
[optional] clone / _clone_dim_order
62+
aten.pixel_shuffle.default (upscale_factor = r)
63+
[optional] clone / _clone_dim_order
64+
q8ta_quantize_per_tensor (fp32 -> int8)
65+
66+
The anchor is the dequantize node since it is a unique entry point.
67+
68+
This relies on the partitioner's `ops_to_not_decompose()` hook preserving
69+
`aten.pixel_shuffle.default` through edge lowering, so we do not need to
70+
re-detect the decomposed view -> permute -> view pattern.
71+
"""
72+
73+
def __init__(self, dequantize_node: torch.fx.Node) -> None:
74+
self.anchor_node: torch.fx.Node = dequantize_node
75+
self.match_found: bool = False
76+
self.all_nodes: List[torch.fx.Node] = [dequantize_node]
77+
78+
# Validate the dequantize node is one of the quant decomposed ops.
79+
if not utils.is_dequant_node(dequantize_node):
80+
return
81+
82+
# Walk forward to the pixel_shuffle node (skipping any clones).
83+
pixel_shuffle_node = _skip_passthrough_user(dequantize_node, self.all_nodes)
84+
if pixel_shuffle_node is None:
85+
return
86+
if pixel_shuffle_node.op != "call_function":
87+
return
88+
if pixel_shuffle_node.target != exir_ops.edge.aten.pixel_shuffle.default:
89+
return
90+
91+
# Walk forward to the quantize node (skipping any clones).
92+
quantize_node = _skip_passthrough_user(pixel_shuffle_node, self.all_nodes)
93+
if quantize_node is None or not utils.is_quant_node(quantize_node):
94+
return
95+
96+
# pixel_shuffle args are (input, upscale_factor).
97+
if len(pixel_shuffle_node.args) < 2:
98+
return
99+
upscale_factor = pixel_shuffle_node.args[1]
100+
if not isinstance(upscale_factor, int):
101+
return
102+
103+
# Capture the nodes and quant params we need for the replacement.
104+
self.dequantize_input_node = dequantize_node
105+
self.pixel_shuffle_node: torch.fx.Node = pixel_shuffle_node
106+
self.quantize_output_node: torch.fx.Node = quantize_node
107+
108+
self.input_int8_node: Argument = dequantize_node.args[0]
109+
self.input_scales_node: Argument = dequantize_node.args[1]
110+
self.input_zeros_node: Argument = dequantize_node.args[2]
111+
self.output_scales_node: Argument = quantize_node.args[1]
112+
self.output_zeros_node: Argument = quantize_node.args[2]
113+
self.upscale_factor: int = upscale_factor
114+
115+
self.all_nodes.extend([pixel_shuffle_node, quantize_node])
116+
# The replacement target replaces uses of the quantize node.
117+
self.output_node: torch.fx.Node = quantize_node
118+
119+
self.match_found = True
120+
121+
122+
@register_pattern_detector("quantized_pixel_shuffle")
123+
def find_quantized_pixel_shuffle_pattern(
124+
node: torch.fx.Node,
125+
) -> Optional[QuantizedPixelShuffleMatch]:
126+
if node.op != "call_function":
127+
return None
128+
if not utils.is_dequant_node(node):
129+
return None
130+
matched = QuantizedPixelShuffleMatch(node)
131+
if matched.match_found:
132+
return matched
133+
return None
134+
135+
136+
@register_pattern_replacement("quantized_pixel_shuffle")
137+
def make_quantized_pixel_shuffle_custom_op(
138+
ep: ExportedProgram,
139+
graph_module: torch.fx.GraphModule,
140+
match: QuantizedPixelShuffleMatch,
141+
) -> None:
142+
op_target = exir_ops.edge.et_vk.q8ta_pixel_shuffle.default
143+
144+
# The fused op takes the *inverse* of the output scale to match the
145+
# runtime kernel's expectation.
146+
output_scale = match.output_scales_node
147+
inv_output_scale: object
148+
if isinstance(output_scale, (int, float)):
149+
inv_output_scale = 1.0 / float(output_scale)
150+
else:
151+
# Intentional bail-out at the replacement step (not a TODO). The
152+
# matcher deliberately does not pre-validate that the output scale is
153+
# scalar because every observed quantize_per_tensor in real models has
154+
# a baked-in float scale; if that assumption breaks, we want a loud
155+
# failure here at fusion time rather than a silent miscompile.
156+
# If the output scale is a graph node (rare for static per-tensor
157+
# quant, but possible), insert a reciprocal computation. For all the
158+
# cases observed in the model the scales are baked-in floats, so we
159+
# raise here to make the failure visible rather than producing a
160+
# silent miscompile.
161+
raise NotImplementedError(
162+
"quantized_pixel_shuffle pattern only supports scalar output scales"
163+
)
164+
165+
with graph_module.graph.inserting_before(match.output_node):
166+
new_node = graph_module.graph.create_node(
167+
"call_function",
168+
op_target,
169+
args=(
170+
match.input_int8_node,
171+
match.input_scales_node,
172+
match.input_zeros_node,
173+
inv_output_scale,
174+
match.output_zeros_node,
175+
match.upscale_factor,
176+
),
177+
)
178+
179+
new_node.meta["val"] = match.output_node.meta["val"]
180+
match.quantize_output_node.replace_all_uses_with(new_node)

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ vkapi::DescriptorSet Context::get_descriptor_set(
149149

150150
spec_constants.append(additional_constants);
151151

152+
const uint32_t resolved_required_subgroup_size =
153+
vkapi::resolve_required_subgroup_size(shader_descriptor, adapter_p_);
154+
152155
VkPipeline pipeline = pipeline_cache().retrieve(
153156
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
154157
shader_cache().retrieve(shader_descriptor),
155-
spec_constants});
158+
spec_constants,
159+
resolved_required_subgroup_size});
156160

157161
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
158162

@@ -315,8 +319,14 @@ VkPipeline Context::get_shader_pipeline(
315319

316320
spec_constants.append(additional_constants);
317321

322+
const uint32_t resolved_required_subgroup_size =
323+
vkapi::resolve_required_subgroup_size(shader, adapter_p_);
324+
318325
VkPipeline pipeline = pipeline_cache().retrieve(
319-
{pipeline_layout, shader_cache().retrieve(shader), spec_constants});
326+
{pipeline_layout,
327+
shader_cache().retrieve(shader),
328+
spec_constants,
329+
resolved_required_subgroup_size});
320330

321331
return pipeline;
322332
}

0 commit comments

Comments
 (0)