diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index d3b44e0a9ca..f97053734f9 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -252,10 +252,10 @@ def get_arg_tensor_source_repset( """ arg_node = op_node.args[arg_i] - # For non-tensor arguments, return ALL_STORAGES_REPSET so that the respset does + # For non-tensor arguments, return ANY_STORAGE_INCL_PACKED_INT8 so that the repset does # not appear to be empty. if not utils.is_tensor_arg_node(arg_node): - return utils.ALL_STORAGES_REPSET + return utils.ANY_STORAGE_INCL_PACKED_INT8 # Special case for cat - use the first tensor in the list as representative if isinstance(arg_node, list): diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 4b1b02466ee..4364f67123d 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -954,6 +954,39 @@ def q8ta_relu_impl( lib.impl(name, q8ta_relu_impl, "CompositeExplicitAutograd") q8ta_relu_op = getattr(getattr(torch.ops, namespace), name) +########################### +## q8ta_pixel_shuffle ## +########################### + + +def q8ta_pixel_shuffle_impl( + input: torch.Tensor, + input_scale: float, + input_zero_point: int, + output_inv_scale: float, + output_zero_point: int, + upscale_factor: int, +): + # Reference Python impl for op registration. The runtime kernel does a + # fused byte-shuffle (and optional requantize when scales differ). 
+ output_scale = 1.0 / output_inv_scale + dequant = torch.ops.quantized_decomposed.dequantize_per_tensor( + input, input_scale, input_zero_point, -128, 127, input.dtype + ) + shuffled = torch.nn.functional.pixel_shuffle(dequant, upscale_factor) + requantized = torch.ops.quantized_decomposed.quantize_per_tensor( + shuffled, output_scale, output_zero_point, -128, 127, torch.int8 + ) + return requantized + + +name = "q8ta_pixel_shuffle" +lib.define( + f"{name}(Tensor input, float input_scale, int input_zero_point, float output_inv_scale, int output_zero_point, int upscale_factor) -> Tensor" +) +lib.impl(name, q8ta_pixel_shuffle_impl, "CompositeExplicitAutograd") +q8ta_pixel_shuffle_op = getattr(getattr(torch.ops, namespace), name) + ######################## ## embedding_q4gsw ## ######################## diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9345f0a9090..87f7ea8b996 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -611,6 +611,25 @@ def register_q8ta_relu(): ) +# ============================================================================= +# Q8taPixelShuffle.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.et_vk.q8ta_pixel_shuffle.default) +def register_q8ta_pixel_shuffle(): + # The fused kernel is restricted to the channels-packed family + # (PACKED_INT8_4W4C, PACKED_INT8_4C1W, PACKED_INT8_CONV2D), all of which + # share packed_dim=C. See add_q8ta_pixel_shuffle_node in Q8taPixelShuffle.cpp + # for the runtime assertion. The surrounding q8ta_conv2d ops produce + # PACKED_INT8_4W4C on this model, so the partitioner can route through this + # op without inserting layout-transition q8ta_clone dispatches. 
+ return OpFeatures( + inputs_storage=utils.PACKED_INT8_CHANNELS_PACKED_BUFFER, + supports_resize=True, + ) + + # ============================================================================= # ============================================================================= @@ -1158,7 +1177,7 @@ def register_permute_copy(): @update_features(exir_ops.edge.aten.view_copy.default) def register_view_copy(): return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, supports_highdim=True, @@ -1213,7 +1232,7 @@ def register_unsqueeze_copy(): @update_features(exir_ops.edge.aten.clone.default) def register_clone(): return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, supports_highdim=True, @@ -1223,7 +1242,7 @@ def register_clone(): @update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default) def register_clone_dim_order(): return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, supports_highdim=True, @@ -1237,7 +1256,7 @@ def register_clone_dim_order(): @update_features(exir_ops.edge.aten.alias_copy.default) def register_alias_copy(): return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_STORAGE_INCL_PACKED_INT8, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, supports_highdim=True, @@ -1505,6 +1524,20 @@ def register_upsample_cpp_ops(): ) +# ============================================================================= +# PixelShuffle.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.aten.pixel_shuffle.default) +def register_pixel_shuffle(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.FP_T, + 
supports_resize=True, + ) + + # ============================================================================= # GridPriors.cpp # ============================================================================= diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 4066da70fe5..60b4c3346f3 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -48,6 +48,7 @@ ops_not_to_decompose = [ torch.ops.aten.hardswish.default, torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.pixel_shuffle.default, ] logger: logging.Logger = logging.getLogger("") diff --git a/backends/vulkan/patterns/BUCK b/backends/vulkan/patterns/BUCK index 7fa132fd5cb..7803ba64f60 100644 --- a/backends/vulkan/patterns/BUCK +++ b/backends/vulkan/patterns/BUCK @@ -15,6 +15,7 @@ fbcode_target(_kind = runtime.python_library, "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "quantized_pixel_shuffle.py", "quantized_unary.py", "rms_norm.py", "sdpa.py", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index 86fb82a03d2..68df9905671 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -14,6 +14,8 @@ import executorch.backends.vulkan.patterns.quantized_linear # noqa +import executorch.backends.vulkan.patterns.quantized_pixel_shuffle # noqa + import executorch.backends.vulkan.patterns.quantized_unary # noqa import executorch.backends.vulkan.patterns.rms_norm # noqa diff --git a/backends/vulkan/patterns/quantized_pixel_shuffle.py b/backends/vulkan/patterns/quantized_pixel_shuffle.py new file mode 100644 index 00000000000..1a180863935 --- /dev/null +++ b/backends/vulkan/patterns/quantized_pixel_shuffle.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Set + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + +from torch.fx.node import Argument + + +# Set of ops that act as no-ops on values (i.e. clones / dim_order copies that +# preserve dtype and shape). The matcher transparently skips these between the +# dequantize, pixel_shuffle, and quantize nodes. +_NOOP_PASSTHROUGH_TARGETS: Set[object] = { + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, +} + + +def _is_noop_passthrough(node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in _NOOP_PASSTHROUGH_TARGETS + + +def _skip_passthrough_user( + node: torch.fx.Node, collected: List[torch.fx.Node] +) -> Optional[torch.fx.Node]: + """Given `node`, advance to its next non-passthrough user, walking through + any chain of clone/dim_order_copy ops in between (collecting them in + `collected`). 
Returns None if `node` does not have exactly one user, or if any + intermediate passthrough does not have exactly one user.""" + if len(node.users) != 1: + return None + cur = next(iter(node.users)) + while _is_noop_passthrough(cur): + collected.append(cur) + if len(cur.users) != 1: + return None + cur = next(iter(cur.users)) + return cur + + +class QuantizedPixelShuffleMatch(PatternMatch): + """ + Matches an un-decomposed PixelShuffle wrapped between a quant/dequant pair: + + q8ta_dequantize_per_tensor (int8 -> fp32) + [optional] clone / _clone_dim_order + aten.pixel_shuffle.default (upscale_factor = r) + [optional] clone / _clone_dim_order + q8ta_quantize_per_tensor (fp32 -> int8) + + The anchor is the dequantize node since it is a unique entry point. + + This relies on the partitioner's `ops_to_not_decompose()` hook preserving + `aten.pixel_shuffle.default` through edge lowering, so we do not need to + re-detect the decomposed view -> permute -> view pattern. + """ + + def __init__(self, dequantize_node: torch.fx.Node) -> None: + self.anchor_node: torch.fx.Node = dequantize_node + self.match_found: bool = False + self.all_nodes: List[torch.fx.Node] = [dequantize_node] + + # Validate the dequantize node is one of the quant decomposed ops. + if not utils.is_dequant_node(dequantize_node): + return + + # Walk forward to the pixel_shuffle node (skipping any clones). + pixel_shuffle_node = _skip_passthrough_user(dequantize_node, self.all_nodes) + if pixel_shuffle_node is None: + return + if pixel_shuffle_node.op != "call_function": + return + if pixel_shuffle_node.target != exir_ops.edge.aten.pixel_shuffle.default: + return + + # Walk forward to the quantize node (skipping any clones). + quantize_node = _skip_passthrough_user(pixel_shuffle_node, self.all_nodes) + if quantize_node is None or not utils.is_quant_node(quantize_node): + return + + # pixel_shuffle args are (input, upscale_factor). 
+ if len(pixel_shuffle_node.args) < 2: + return + upscale_factor = pixel_shuffle_node.args[1] + if not isinstance(upscale_factor, int): + return + + # Capture the nodes and quant params we need for the replacement. + self.dequantize_input_node = dequantize_node + self.pixel_shuffle_node: torch.fx.Node = pixel_shuffle_node + self.quantize_output_node: torch.fx.Node = quantize_node + + self.input_int8_node: Argument = dequantize_node.args[0] + self.input_scales_node: Argument = dequantize_node.args[1] + self.input_zeros_node: Argument = dequantize_node.args[2] + self.output_scales_node: Argument = quantize_node.args[1] + self.output_zeros_node: Argument = quantize_node.args[2] + self.upscale_factor: int = upscale_factor + + self.all_nodes.extend([pixel_shuffle_node, quantize_node]) + # The replacement target replaces uses of the quantize node. + self.output_node: torch.fx.Node = quantize_node + + self.match_found = True + + +@register_pattern_detector("quantized_pixel_shuffle") +def find_quantized_pixel_shuffle_pattern( + node: torch.fx.Node, +) -> Optional[QuantizedPixelShuffleMatch]: + if node.op != "call_function": + return None + if not utils.is_dequant_node(node): + return None + matched = QuantizedPixelShuffleMatch(node) + if matched.match_found: + return matched + return None + + +@register_pattern_replacement("quantized_pixel_shuffle") +def make_quantized_pixel_shuffle_custom_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedPixelShuffleMatch, +) -> None: + op_target = exir_ops.edge.et_vk.q8ta_pixel_shuffle.default + + # The fused op takes the *inverse* of the output scale to match the + # runtime kernel's expectation. + output_scale = match.output_scales_node + inv_output_scale: object + if isinstance(output_scale, (int, float)): + inv_output_scale = 1.0 / float(output_scale) + else: + # Intentional bail-out at the replacement step (not a TODO). 
The + # matcher deliberately does not pre-validate that the output scale is + # scalar because every observed quantize_per_tensor in real models has + # a baked-in float scale; if that assumption breaks, we want a loud + # failure here at fusion time rather than a silent miscompile. + # Supporting an output scale that is a graph node (rare for static + # per-tensor quant, but possible) would require inserting a + # reciprocal computation into the graph, which is not implemented. + # Until it is, raise so that the unsupported case is visible to the + # caller instead of being lowered incorrectly. + raise NotImplementedError( + "quantized_pixel_shuffle pattern only supports scalar output scales" + ) + + with graph_module.graph.inserting_before(match.output_node): + new_node = graph_module.graph.create_node( + "call_function", + op_target, + args=( + match.input_int8_node, + match.input_scales_node, + match.input_zeros_node, + inv_output_scale, + match.output_zeros_node, + match.upscale_factor, + ), + ) + + new_node.meta["val"] = match.output_node.meta["val"] + match.quantize_output_node.replace_all_uses_with(new_node) diff --git a/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.glsl new file mode 100644 index 00000000000..196555a279f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.glsl @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +${define_required_extensions(STORAGE, DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_outp", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "upscale_factor", "1")} + +/* + * pixel_shuffle: rearranges (N, C*r*r, H, W) -> (N, C, H*r, W*r). + * + * For output element at NCHW index (n, c, h_out, w_out): + * h_in = h_out / r + * w_in = w_out / r + * c_in = c * r * r + (h_out % r) * r + (w_out % r) + * + * The W, H, C dims correspond to NCHW indices [3], [2], [1] when ndim == 4 + * (ndim - 1, ndim - 2, ndim - 3 in general). We use NCHW dim numbering so the + * mapping is independent of the tensor's memory layout. + */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + TensorIndex outp_tidx = linear_idx_to_tensor_idx(outp, outp_bufi); + + // NCHW dim indices for W (last), H (second-last), C (third-last). + const int nd = int_ndim(outp); + const int w_dim_nchw = nd - 1; + const int h_dim_nchw = nd - 2; + const int c_dim_nchw = nd - 3; + + // Convert NCHW dim index to WHCN dim index expected by indexing helpers, + // where the "dim" parameter is ndim - 1 - nchw_dim (i.e. logical ordering + // matching strides). The tidx.data array stores indices in WHCN order: + // tidx.data[0] = W, tidx.data[1] = H, tidx.data[2] = C, tidx.data[3] = N. 
+ const int r = upscale_factor; + + const uint w_out = idx_at(outp_tidx, 0); + const uint h_out = idx_at(outp_tidx, 1); + const uint c_out = idx_at(outp_tidx, 2); + + const uint w_in = w_out / uint(r); + const uint h_in = h_out / uint(r); + const uint c_in = c_out * uint(r) * uint(r) + + (h_out % uint(r)) * uint(r) + (w_out % uint(r)); + + TensorIndex inp_tidx = outp_tidx; + inp_tidx.data[0][0] = w_in; + inp_tidx.data[0][1] = h_in; + inp_tidx.data[0][2] = c_in; + + const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + + t_outp[outp_bufi] = t_inp[inp_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.yaml new file mode 100644 index 00000000000..610fde0beae --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_buffer.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +pixel_shuffle_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: pixel_shuffle_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.glsl new file mode 100644 index 00000000000..46b20783bc4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.glsl @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +${define_required_extensions(STORAGE, DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${texel_load_component_type(DTYPE, STORAGE)} + +${define_active_storage_type(STORAGE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "upscale_factor", "1")} + +const int out_packed_dim = get_packed_dim(out_layout); + +/* + * pixel_shuffle: rearranges (N, C*r*r, H, W) -> (N, C, H*r, W*r). + * + * For output element at NCHW index (n, c, h_out, w_out): + * w_in = w_out / r + * h_in = h_out / r + * c_in = c * r * r + (h_out % r) * r + (w_out % r) + * + * Each thread writes one output texel of 4 components along the packed dim. + * Each component may map to a different input texel, so we resolve per- + * component and use texelFetch on the input. + */ +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = + texture_pos_to_tensor4d_idx_simple(outp, out_pos, out_layout); + + // safe_idx() avoids dynamic UBO-vector indexing, which crashes Adreno 740. + // The output may not span a full block of 4 along the packed dim if the + // packed-dim size is not a multiple of 4, so clamp the loop. 
+ const int limit = min( + 4, + safe_idx(outp.sizes, out_packed_dim) - + safe_idx(out_tidx.data, out_packed_dim)); + + const int r = upscale_factor; + + VEC4_T out_texel = VEC4_T(0); + for (int comp = 0; comp < 4; comp++) { + if (comp >= limit) { + break; + } + + // Build the per-component output tensor index. tidx.data is a local + // ivec4 in WHCN order ([0]=W, [1]=H, [2]=C, [3]=N), so dynamic indexing + // here is safe (not UBO-backed). + TensorIndex4D out_tidx_c = out_tidx; + safe_set( + out_tidx_c.data, + out_packed_dim, + safe_idx(out_tidx.data, out_packed_dim) + comp); + + const int w_out = out_tidx_c.data.x; + const int h_out = out_tidx_c.data.y; + const int c_out = out_tidx_c.data.z; + + const int w_in = w_out / r; + const int h_in = h_out / r; + const int c_in = c_out * r * r + (h_out % r) * r + (w_out % r); + + TensorIndex4D in_tidx; + in_tidx.data = ivec4(w_in, h_in, c_in, out_tidx_c.data.w); + + TextureElementIndex in_elem = + tensor4d_idx_to_texture_element_idx_simple(inp, in_tidx, in_layout); + VEC4_T in_texel = texelFetch(t_in, in_elem.pos, 0); + out_texel[comp] = in_texel[in_elem.comp]; + } + + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.yaml new file mode 100644 index 00000000000..45dfe5d19d9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/pixel_shuffle_texture.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +pixel_shuffle_texture: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: pixel_shuffle_texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl index fc063579c45..9f60bea9948 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_pw.glsl @@ -22,16 +22,18 @@ $if USE_INT8_DOT_PRODUCT_EXT == 1: ${define_active_storage_type("buffer")} +// Each thread computes a TILE_M (width) x TILE_N (output channel) output block, +// using an int32 accumulator tile. // corresponds to input/output width dim #define TILE_M4 1 // corresponds to input channels dim #define TILE_K4 1 // corresponds to output channels dim -#define TILE_N4 2 +#define TILE_N4 1 #define TILE_M 4 #define TILE_K 4 -#define TILE_N 8 +#define TILE_N 4 layout(std430) buffer; @@ -86,9 +88,9 @@ int compute_outp_buffer_idx( } void main() { - // Thread mapping: each thread handles TILE_M (4) widths × TILE_N (8) output channels - // gl_GlobalInvocationID.x → output channel blocks (TILE_N4 = 2 blocks of 4 channels) - // gl_GlobalInvocationID.y → width blocks (TILE_M4 = 1 block of 4 widths) + // Thread mapping: each thread handles TILE_M widths x TILE_N output channels. + // gl_GlobalInvocationID.x -> output channel blocks. + // gl_GlobalInvocationID.y -> width blocks. // gl_GlobalInvocationID.z → batch (or height * batch combined) const int oc_block_idx = int(gl_GlobalInvocationID.x) * TILE_N4; const int ow_block_idx = int(gl_GlobalInvocationID.y) * TILE_M4; @@ -137,11 +139,11 @@ void main() { // Main accumulation loop over K dimension for (int k4 = 0; k4 < K4_per_group; k4++) { - // Load packed int8 input tile (TILE_M4=1, TILE_K4=1) + // Load the packed int8 input tile for the current width and K sub-block. 
// Each int contains 4 packed int8s (one per width position in the tile) ivec4 int8_input_tile = t_packed_int8_input[input_idx]; - // Load int8 weight tile (TILE_K4=1, TILE_N4=2) + // Load the int8 weight tile for the current K and output-channel sub-block. ivec4 int8_weight_tile[TILE_N4]; [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { int8_weight_tile[n4] = texelFetch( diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.glsl new file mode 100644 index 00000000000..a2877f2b3ba --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.glsl @@ -0,0 +1,159 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output buffer: packed int8x4 (each int32 = 4 packed int8 along packed_dim) +${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")} +// Input buffer: packed int8x4 (each int32 = 4 packed int8 along packed_dim) +${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")} + +// Metadata for output and input tensors. Both are int8x4 packed buffers but +// may use different block layouts (e.g. PACKED_INT8_4C1W vs PACKED_INT8_4W4C). +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(push_constant) uniform restrict Block { + float scale_in; + float inv_scale_out; + int zp_in; + int zp_out; + int upscale_factor; + // Whether we can skip the requantize math and do a pure byte shuffle. + // Set to 1 by the host when (scale_in == scale_out) && (zp_in == zp_out). 
+ int passthrough; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +/* + * PixelShuffle map: out[n, c_out, oh, ow] = + * in[n, c_out * r * r + (oh % r) * r + (ow % r), oh / r, ow / r]. + * + * Each thread produces one output int32 word (= 4 consecutive output channels + * at one (n, oh, ow) spatial position). Channels are the packed dim and + * packed_dim_block_size is 4, so writing one int word fills 4 channel lanes. + * + * The four channel lanes inside an output int come from four DIFFERENT input + * words (channels spaced by r*r in the input), so each thread issues 4 input + * loads. The (oh % r, ow % r) -> input lane mapping is fixed for a given + * thread because all four output lanes share (oh, ow). Out-of-output-bounds + * channel lanes (when C_out is not a multiple of 4) are zero-filled. + * + * Supported layouts: channels-packed family (PACKED_INT8_4W4C, PACKED_INT8_4C1W, + * PACKED_INT8_CONV2D). Layout-aware byte indexing is handled by + * tensor4d_idx_to_buf_idx (which consumes inp_layout / outp_layout). + */ + +// Apply requantize to an int8 lane value. +int requantize_lane(const int q_in) { + if (passthrough != 0) { + return q_in; + } + // Requantize: round((q_in - zp_in) * scale_in * inv_scale_out) + zp_out, + // clamped to int8. 
+ float dq = float(q_in - zp_in) * scale_in; + int qv = int(round(dq * inv_scale_out)) + zp_out; + qv = clamp(qv, -128, 127); + return qv; +} + +void main() { + // Output sizes (WHCN order via meta.sizes[0]) + const int W_out = int(safe_idx(outp.sizes[0], 0)); + const int H_out = int(safe_idx(outp.sizes[0], 1)); + const int C_out = int(safe_idx(outp.sizes[0], 2)); + const int N = int(safe_idx(outp.sizes[0], 3)); + + // Input sizes + const int W_in = int(safe_idx(inp.sizes[0], 0)); + const int H_in = int(safe_idx(inp.sizes[0], 1)); + const int C_in = int(safe_idx(inp.sizes[0], 2)); + + // One thread per output int32 word: word covers 4 consecutive channels + // (along the packed dim) at one (n, oh, ow) spatial position. + const int C_words = div_up_4(C_out); + const int total_words = N * C_words * H_out * W_out; + const int thread_idx = int(gl_GlobalInvocationID.x); + if (thread_idx >= total_words) { + return; + } + + // Decode thread_idx in (W_out, H_out, C_words, N) order. + const int ow = thread_idx % W_out; + const int oh = (thread_idx / W_out) % H_out; + const int c_word = (thread_idx / (W_out * H_out)) % C_words; + const int n = thread_idx / (W_out * H_out * C_words); + const int c_out_base = c_word * 4; + + const int r = upscale_factor; + // (oh % r, ow % r) determines which input channel lane within the input + // word group of size r*r — constant for all 4 output channel lanes here. + const int offset = (oh % r) * r + (ow % r); + const int ih = oh / r; + const int iw = ow / r; + + const int c_in_first = c_out_base * r * r + offset; + + // Compute byte_idx for the first lane (i=0) via the layout-aware helper. + TensorIndex4D inp_idx; + inp_idx.data = ivec4(iw, ih, c_in_first, n); + const int byte_idx_first = tensor4d_idx_to_buf_idx(inp, inp_idx, inp_layout); + + // byte_stride between successive c_in advances of r*r = inner_block_size = 4. + // Each advance bumps the block-space C coord by 1, so byte_idx grows by + // stride[inner_dim] * block_numel. 
Both factors are layout-only, no second + // helper call needed. (Assumes r*r == inner_block_size == 4, enforced by the + // C++ dispatch's r==2 and packed_dim_block_size==4 asserts.) + const int byte_stride = + int(stride_at(inp, get_packed_dim(inp_layout))) * get_block_numel(inp_layout); + + // lane is the byte position within an int32 word, which equals + // (intra_block_idx % 4) since block_numel is a multiple of 4. And + // intra_block_idx % 4 == inner_offset == c_in_first % 4 == offset. + const int lane = offset; + + int packed_out = 0; + [[unroll]] for (int i = 0; i < 4; ++i) { + const int c_out_lane = c_out_base + i; + int q_out = 0; + if (c_out_lane < C_out) { + const int c_in = c_in_first + i * r * r; + int q_in; + if (iw >= W_in || ih >= H_in || c_in >= C_in) { + q_in = zp_in; + } else { + const int byte_idx = byte_idx_first + i * byte_stride; + const int word_idx = div_4(byte_idx); + const int packed = t_inp[word_idx]; + // Sign-extend from 8-bit + q_in = ((packed >> (lane * 8)) << 24) >> 24; + } + q_out = requantize_lane(q_in); + } + packed_out |= (q_out & 0xFF) << (i * 8); + } + + // Store the packed int directly. Output's packed dim is channels with + // block size 4, so the byte index for c_out_base aligns to a word boundary. + TensorIndex4D outp_idx; + outp_idx.data = ivec4(ow, oh, c_out_base, n); + const int outp_byte_idx = tensor4d_idx_to_buf_idx(outp, outp_idx, outp_layout); + t_outp[div_4(outp_byte_idx)] = packed_out; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.yaml new file mode 100644 index 00000000000..7aec357a7a2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_pixel_shuffle.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +q8ta_pixel_shuffle: + parameter_names_with_default_values: + DTYPE: int + shader_variants: + - NAME: q8ta_pixel_shuffle diff --git a/backends/vulkan/runtime/graph/ops/impl/PixelShuffle.cpp b/backends/vulkan/runtime/graph/ops/impl/PixelShuffle.cpp new file mode 100644 index 00000000000..24c00b9d7af --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/PixelShuffle.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +void resize_pixel_shuffle_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const ValueRef upscale_factor_ref = resize_args.at(0); + + const int64_t r = graph->extract_scalar(upscale_factor_ref); + + std::vector in_sizes = graph->sizes_of(in); + const int64_t ndim = static_cast(in_sizes.size()); + VK_CHECK_COND(ndim >= 3); + + std::vector out_sizes = in_sizes; + out_sizes.at(ndim - 3) = in_sizes.at(ndim - 3) / (r * r); + out_sizes.at(ndim - 2) = in_sizes.at(ndim - 2) * r; + out_sizes.at(ndim - 1) = in_sizes.at(ndim - 1) * r; + + graph->virtual_resize(out, out_sizes); +} + +void add_pixel_shuffle_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef upscale_factor_ref, + const ValueRef out) { + const int64_t r = graph.extract_scalar(upscale_factor_ref); + VK_CHECK_COND(r >= 1); + + const std::vector in_sizes = graph.sizes_of(in); + const int64_t ndim = static_cast(in_sizes.size()); + VK_CHECK_COND(ndim >= 3); + VK_CHECK_COND(in_sizes.at(ndim - 3) % (r * r) == 0); + + std::string kernel_name = "pixel_shuffle"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + 
add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList ubos = {graph.meta_ubo(out), graph.meta_ubo(in)}; + + vkapi::SpecVarList spec_constants = { + graph.hashed_layout_of(out), + graph.hashed_layout_of(in), + static_cast(r)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + ubos, + // Push Constants + {}, + // Specialization Constants + spec_constants, + // Resize Args + {upscale_factor_ref}, + // Resizing Logic + resize_pixel_shuffle_node)); +} + +void pixel_shuffle(ComputeGraph& graph, const std::vector& args) { + const ValueRef in = args[0]; + const ValueRef upscale_factor_ref = args[1]; + const ValueRef out = args[2]; + add_pixel_shuffle_node(graph, in, upscale_factor_ref, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.pixel_shuffle.default, pixel_shuffle); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp index e27e0699dac..7a2380f728a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dPW.cpp @@ -33,19 +33,17 @@ utils::uvec3 pick_q8ta_conv2d_pw_global_wg_size( const uint32_t H = graph->size_at(-2, output); const uint32_t C = graph->size_at(-3, output); - // The 4W4C shader processes tiles of: - // - TILE_N4=2 groups of 4 output channels (8 channels per thread) - // - TILE_M4=1 groups of 4 widths (4 widths per thread) - // - 1 height per thread - constexpr uint32_t TILE_N4 = 2; + // Each thread covers a 4-width x 4-channel output block. + // Tile constants must match TILE_M4 / TILE_N4 in q8ta_conv2d_pw.glsl. 
+ constexpr uint32_t TILE_N4 = 1; constexpr uint32_t TILE_M4 = 1; const uint32_t C4 = utils::div_up_4(C); const uint32_t W4 = utils::div_up_4(W); // Global workgroup size: - // x = output channels / (TILE_N4 * 4) = C4 / TILE_N4 - // y = width / (TILE_M4 * 4) = W4 / TILE_M4 + // x = output channels / (TILE_N4 * 4) = C4 / TILE_N4 = C4 + // y = width / (TILE_M4 * 4) = W4 / TILE_M4 = W4 // z = height return {utils::div_up(C4, TILE_N4), utils::div_up(W4, TILE_M4), H}; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.cpp new file mode 100644 index 00000000000..74712654fd4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.cpp @@ -0,0 +1,229 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include + +#include +#include + +namespace vkcompute { + +namespace { + +void resize_q8ta_pixel_shuffle_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + // resize_args[0] is the upscale factor packed into a ValueRef (an int). + const int32_t r = static_cast(resize_args.at(0)); + + const std::vector in_sizes = graph->sizes_of(in); + VK_CHECK_COND(in_sizes.size() == 4); + // Input is [N, C, H, W]; output is [N, C/r/r, H*r, W*r]. + std::vector out_sizes = in_sizes; + out_sizes[1] = in_sizes[1] / (r * r); + out_sizes[2] = in_sizes[2] * r; + out_sizes[3] = in_sizes[3] * r; + graph->virtual_resize(out, out_sizes); +} + +// Global wg picker: one thread per output int32 word. For a channels-packed +// int8x4 output with channel block size 4, the number of output int words is +// N * div_up_4(C_out) * H_out * W_out. 
+utils::uvec3 pick_q8ta_pixel_shuffle_global_wg( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const std::vector& sizes = graph->sizes_of(out); + const int64_t N = utils::val_at(-4, sizes); + const int64_t C = utils::val_at(-3, sizes); + const int64_t H = utils::val_at(-2, sizes); + const int64_t W = utils::val_at(-1, sizes); + const int64_t c_words = utils::div_up(C, int64_t(4)); + const uint32_t total_words = + utils::safe_downcast(N * c_words * H * W); + return {total_words, 1u, 1u}; +} + +utils::uvec3 pick_q8ta_pixel_shuffle_local_wg( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; + (void)args; + (void)resize_args; + // Linear (1D) dispatch: a flat 64-wide workgroup matches the pattern used + // by pick_square_local_wg_with_block_config in the linear case. + return {64u, 1u, 1u}; +} + +} // namespace + +// +// Dispatch +// + +void add_q8ta_pixel_shuffle_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_inv_scale, + const ValueRef output_zp, + const ValueRef upscale_factor, + const ValueRef packed_int8_output) { + // dtype must be the int8x4 packed type + VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4); + VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4); + + // Both tensors must be buffer-backed + VK_CHECK_COND(graph.is_buffer_storage(packed_int8_input)); + VK_CHECK_COND(graph.is_buffer_storage(packed_int8_output)); + + // Tensors must be 4D (N, C, H, W). 
+ VK_CHECK_COND(graph.dim_of(packed_int8_input) == 4); + VK_CHECK_COND(graph.dim_of(packed_int8_output) == 4); + + // Both tensors must use a channels-packed int8x4 layout (packed_dim=C with + // packed_dim_block_size=4). The supported layouts are PACKED_INT8_4W4C + // (outer block on W), PACKED_INT8_4C1W (no outer block), and + // PACKED_INT8_CONV2D. Each output thread writes one int word covering 4 + // consecutive channels at one (n, oh, ow) position. + const api::PackedDimInfo& in_info = + graph.packed_dim_info_of(packed_int8_input); + const api::PackedDimInfo& out_info = + graph.packed_dim_info_of(packed_int8_output); + VK_CHECK_COND(in_info.packed_dim_block_size == 4); + VK_CHECK_COND(out_info.packed_dim_block_size == 4); + // Channels-packed only: packed_dim must be the channels axis (WHCN dim 2). + VK_CHECK_COND(in_info.packed_dim == WHCN::kChannelsDim); + VK_CHECK_COND(out_info.packed_dim == WHCN::kChannelsDim); + + // Upscale factor: only r=2 is exercised by the model and tests. + const int32_t r = graph.extract_scalar(upscale_factor); + VK_CHECK_COND(r == 2); + + // Validate shape relationship: out = [N, C/r/r, H*r, W*r] given in = + // [N, C, H, W]. + const std::vector in_sizes = graph.sizes_of(packed_int8_input); + const std::vector out_sizes = graph.sizes_of(packed_int8_output); + VK_CHECK_COND(in_sizes[0] == out_sizes[0]); + VK_CHECK_COND(in_sizes[1] == out_sizes[1] * r * r); + VK_CHECK_COND(in_sizes[2] * r == out_sizes[2]); + VK_CHECK_COND(in_sizes[3] * r == out_sizes[3]); + + // Push constants + float scale_in = graph.extract_scalar(input_scale); + float scale_out_actual = 1.0f / graph.extract_scalar(output_inv_scale); + float inv_scale_out = graph.extract_scalar(output_inv_scale); + int32_t zp_in = graph.extract_scalar(input_zp); + int32_t zp_out = graph.extract_scalar(output_zp); + + // Detect the pure-byte-shuffle case: same scale & same zero-point. In that + // case the shader can skip the requantize math entirely. 
+ // Use a small relative tolerance on the scales. + const float scale_diff = std::abs(scale_in - scale_out_actual); + const float scale_thresh = 1e-7f * std::max(std::abs(scale_in), 1e-7f); + int32_t passthrough = 0; + if (scale_diff <= scale_thresh && zp_in == zp_out) { + passthrough = 1; + } + int32_t r_val = r; + + std::vector push_constants = { + PushConstantDataInfo(&scale_in, sizeof(scale_in)), + PushConstantDataInfo(&inv_scale_out, sizeof(inv_scale_out)), + PushConstantDataInfo(&zp_in, sizeof(zp_in)), + PushConstantDataInfo(&zp_out, sizeof(zp_out)), + PushConstantDataInfo(&r_val, sizeof(r_val)), + PushConstantDataInfo(&passthrough, sizeof(passthrough)), + }; + + // UBOs + vkapi::ParamsBindList ubos; + ubos.append(graph.buffer_meta_ubo(packed_int8_output)); + ubos.append(graph.buffer_meta_ubo(packed_int8_input)); + + // Each thread writes one int32 output word (= 4 consecutive output channels + // at one (n, oh, ow) spatial position). Total threads = + // N * div_up_4(C_out) * H_out * W_out + // The custom global wg picker computes this from the output sizes; the + // shader internally re-derives the same decomposition from + // gl_GlobalInvocationID. The resize-args list still carries the upscale + // factor so the resize callback can stamp the output size. 
+ const ValueRef r_resize_arg = static_cast(r); + + std::string kernel_name = "q8ta_pixel_shuffle"; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_pixel_shuffle_global_wg, + pick_q8ta_pixel_shuffle_local_wg, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + ubos, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(packed_int8_input), + graph.hashed_layout_of(packed_int8_output)}, + // Resize args: [upscale_factor] + {r_resize_arg}, + // Resizing Logic + resize_q8ta_pixel_shuffle_node)); +} + +// +// High level operator impl +// + +void q8ta_pixel_shuffle( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_inv_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef upscale_factor = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx); + + add_q8ta_pixel_shuffle_node( + graph, + packed_int8_input, + input_scale, + input_zp, + output_inv_scale, + output_zp, + upscale_factor, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_pixel_shuffle.default, q8ta_pixel_shuffle); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.h b/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.h new file mode 100644 index 00000000000..7522b71e042 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taPixelShuffle.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Fused PixelShuffle operating on int8x4 packed tensors. +// +// Replaces the decomposed chain: +// q8ta_dequantize -> view -> permute -> view -> q8ta_quantize +// with a single byte-shuffle (and optional requantize when scales differ). +// + +void add_q8ta_pixel_shuffle_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef output_inv_scale, + const ValueRef output_zp, + const ValueRef upscale_factor, + const ValueRef packed_int8_output); + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taPixelShuffle.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taPixelShuffle.cpp new file mode 100644 index 00000000000..70dd8c9ca39 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taPixelShuffle.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +namespace { + +// Map a layout name string from the test driver to the corresponding +// channels-packed int8x4 GPUMemoryLayout enum. +// +// Note: PACKED_INT8_CONV2D is a Python/serialization-level alias that the +// runtime resolves to kPackedInt8_4C1W (see VulkanBackend.cpp). At the C++ +// runtime layer there is no distinct kPackedInt8_CONV2D enum, so testing +// "CONV2D" here would just be duplicate runtime work over "4C1W". We +// therefore only expose the two real enum values to the test driver. 
+utils::GPUMemoryLayout layout_from_string(const std::string& s) { + if (s == "4W4C") { + return utils::kPackedInt8_4W4C; + } else if (s == "4C1W") { + return utils::kPackedInt8_4C1W; + } + VK_THROW("Unknown q8ta layout name: " + s); +} + +} // namespace + +// +// Test op: takes a float input and float (re)quantization params, performs +// quantize -> fused pixel_shuffle -> dequantize, returns float output. +// +// The test op signature is: +// test_q8ta_pixel_shuffle( +// Tensor fp_input, +// float input_scale, +// int input_zp, +// float output_scale, +// int output_zp, +// int upscale_factor, +// str in_layout, // "4W4C" | "4C1W" +// str out_layout, // "4W4C" | "4C1W" +// ) -> Tensor (float, output shape = pixel_shuffle(in, r)) +// + +void test_q8ta_pixel_shuffle( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef upscale_factor = args.at(idx++); + const ValueRef in_layout_ref = args.at(idx++); + const ValueRef out_layout_ref = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + const int32_t r = graph.extract_scalar(upscale_factor); + + const std::string in_layout_str = graph.extract_string(in_layout_ref); + const std::string out_layout_str = graph.extract_string(out_layout_ref); + const utils::GPUMemoryLayout in_layout = layout_from_string(in_layout_str); + const utils::GPUMemoryLayout out_layout = layout_from_string(out_layout_str); + + const std::vector in_sizes = graph.sizes_of(fp_input); + VK_CHECK_COND(in_sizes.size() == 4); + const int64_t N = in_sizes[0]; + const int64_t C = in_sizes[1]; + const int64_t H = in_sizes[2]; + const int64_t W = in_sizes[3]; + std::vector out_sizes = {N, C / (r * r), H * r, W * r}; + + // Quantize fp_input to int8x4 with the channels-packed input 
layout. + TmpTensor q_in(&graph, in_sizes, vkapi::kInt8x4, utils::kBuffer, in_layout); + add_q8ta_quantize_node(graph, fp_input, input_scale, input_zp, q_in); + + // int8x4 output tensor with the channels-packed output layout. + TmpTensor q_out( + &graph, out_sizes, vkapi::kInt8x4, utils::kBuffer, out_layout); + + // Fused fast path. The fused kernel takes inv_scale, so compute it from + // output_scale here. + float output_scale_val = graph.extract_scalar(output_scale); + float output_inv_scale_val = 1.0f / output_scale_val; + ValueRef inv_scale_ref = graph.add_scalar(output_inv_scale_val); + add_q8ta_pixel_shuffle_node( + graph, + q_in, + input_scale, + input_zp, + inv_scale_ref, + output_zp, + upscale_factor, + q_out); + + // Dequantize back to fp for correctness comparison. + add_q8ta_dequantize_node(graph, q_out, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + test_etvk.test_q8ta_pixel_shuffle.default, test_q8ta_pixel_shuffle); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 5fb0f7f4cbf..b222d475f62 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -99,6 +99,7 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d_dw") define_custom_op_test_binary("test_q8ta_linear") define_custom_op_test_binary("test_q8ta_conv2d_transposed") + define_custom_op_test_binary("test_q8ta_pixel_shuffle") define_custom_op_test_binary("test_mm") define_custom_op_test_binary("test_conv2d_pw") define_custom_op_test_binary("test_conv2d_dw") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_pixel_shuffle.cpp b/backends/vulkan/test/custom_ops/test_q8ta_pixel_shuffle.cpp new file mode 100644 index 00000000000..a21912518c7 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_pixel_shuffle.cpp @@ -0,0 +1,304 @@ +// Copyright (c) Meta Platforms, Inc. 
and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +// Test op signature in TestQ8taPixelShuffle.cpp: +// test_q8ta_pixel_shuffle(fp_in, in_scale, in_zp, out_scale, out_zp, +// upscale_factor, in_layout, out_layout) -> fp_out +// Implementation: fused fast-path kernel. The in_layout / out_layout strings +// select the channels-packed int8x4 layout used for the temporary quantized +// tensors. Supported values: "4W4C", "4C1W". +// (PACKED_INT8_CONV2D is a Python/serialization-level alias that the runtime +// resolves to kPackedInt8_4C1W, so it is not exercised separately here -- it +// would only re-test the same C++ kernel path as "4C1W".) + +struct PixelShuffleConfig { + std::vector in_shape; // [N, C*r*r, H, W] + int upscale_factor; + bool same_qparams; // if true, in_scale == out_scale and in_zp == out_zp + std::string in_layout = "4W4C"; + std::string out_layout = "4W4C"; + std::string test_case_name = "ACCU"; + std::string op_name = "test_q8ta_pixel_shuffle"; +}; + +TestCase create_test_case_from_config(const PixelShuffleConfig& config) { + TestCase test_case; + + std::string shape_str = shape_string(config.in_shape); + std::string qp_label = config.same_qparams ? "same_qp" : "diff_qp"; + std::string layout_label = + "[" + config.in_layout + "->" + config.out_layout + "]"; + std::string test_name = config.test_case_name + " In=" + shape_str + + " r=" + std::to_string(config.upscale_factor) + " " + qp_label + " " + + layout_label; + test_case.set_name(test_name); + + std::string operator_name = "test_etvk." 
+ config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // FP input: shape [N, C_in, H, W] + ValueSpec input_tensor( + config.in_shape, + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::RANDOM); + + float input_scale_val = 0.007112f; + ValueSpec input_scale(input_scale_val); + int32_t input_zp_val = 0; + ValueSpec input_zp(input_zp_val); + + float output_scale_val = config.same_qparams ? input_scale_val : 0.013f; + ValueSpec output_scale(output_scale_val); + int32_t output_zp_val = config.same_qparams ? input_zp_val : 5; + ValueSpec output_zp(output_zp_val); + + ValueSpec upscale_factor(static_cast(config.upscale_factor)); + + ValueSpec in_layout_spec = ValueSpec::make_string(config.in_layout); + ValueSpec out_layout_spec = ValueSpec::make_string(config.out_layout); + + // Output shape + std::vector out_shape = { + config.in_shape[0], + config.in_shape[1] / (config.upscale_factor * config.upscale_factor), + config.in_shape[2] * config.upscale_factor, + config.in_shape[3] * config.upscale_factor}; + ValueSpec output_tensor( + out_shape, + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked, + DataGenType::ZEROS); + + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zp); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zp); + test_case.add_input_spec(upscale_factor); + test_case.add_input_spec(in_layout_spec); + test_case.add_input_spec(out_layout_spec); + test_case.add_output_spec(output_tensor); + + // Tolerance: ~1 quant step in the bigger of the two scales. + float tol = std::max(input_scale_val, output_scale_val) + 1e-4f; + test_case.set_abs_tolerance(tol); + + // Filter shaders that are just measurement overhead (staging copies and the + // surrounding quantize/dequantize that wrap the operation under test). 
+ test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// All (in_layout, out_layout) pairs across the channels-packed int8 family. +// CONV2D is a Python-level alias that resolves to 4C1W at the C++ runtime, so +// it is not listed separately -- it would just re-run the 4C1W kernel path. +static const std::vector>& +get_layout_pairs() { + static const std::vector> layout_pairs = { + {"4W4C", "4W4C"}, + {"4W4C", "4C1W"}, + {"4C1W", "4W4C"}, + {"4C1W", "4C1W"}, + }; + return layout_pairs; +} + +std::vector generate_correctness_cases() { + std::vector test_cases; + + // Small shapes, all use r=2 (the only factor needed by the model). + // Shape format is the *input* shape [N, C_in, H, W] where C_in = C_out * r*r. + std::vector> shapes = { + // Small even W to be a multiple of 4 after upscaling. + {1, 16, 4, 4}, // out: [1, 4, 8, 8] + {1, 24, 8, 4}, // out: [1, 6, 16, 8] + {1, 32, 12, 8}, // out: [1, 8, 24, 16] + {1, 96, 16, 9}, // out: [1, 24, 32, 18] - first model shape + }; + + for (const auto& shape : shapes) { + for (bool same_qp : {true, false}) { + for (const auto& layouts : get_layout_pairs()) { + PixelShuffleConfig cfg; + cfg.in_shape = shape; + cfg.upscale_factor = 2; + cfg.same_qparams = same_qp; + cfg.in_layout = layouts.first; + cfg.out_layout = layouts.second; + cfg.test_case_name = "ACCU"; + + test_cases.push_back(create_test_case_from_config(cfg)); + } + } + } + + return test_cases; +} + +std::vector generate_perf_cases() { + std::vector test_cases; + + // Model perf shapes (output shapes from the RefineNet decoder; we compute + // the input shape as [N, C_out * r*r, H_out / r, W_out / r]). + // Output shapes: [1, 24, 32, 18], [1, 24, 64, 36], [1, 24, 128, 72], + // [1, 24, 256, 144]. For r=2, in shapes = [1, 96, 16, 9], etc. 
+ std::vector> in_shapes = { + {1, 96, 16, 9}, + {1, 96, 32, 18}, + {1, 96, 64, 36}, + {1, 96, 128, 72}, + }; + + for (const auto& shape : in_shapes) { + for (const auto& layouts : get_layout_pairs()) { + PixelShuffleConfig cfg; + cfg.in_shape = shape; + cfg.upscale_factor = 2; + cfg.same_qparams = true; // residual-style: scales match + cfg.in_layout = layouts.first; + cfg.out_layout = layouts.second; + cfg.test_case_name = "PERF"; + + test_cases.push_back(create_test_case_from_config(cfg)); + } + } + + return test_cases; +} + +std::vector generate_all_cases() { + std::vector all = generate_correctness_cases(); + std::vector perf = generate_perf_cases(); + for (auto& tc : perf) { + all.push_back(tc); + } + return all; +} + +// Reference: quantize input, do PyTorch-equivalent pixel shuffle, requantize, +// then dequantize for comparison. +void q8ta_pixel_shuffle_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zp_spec = test_case.inputs()[idx++]; + const ValueSpec& upscale_factor_spec = test_case.inputs()[idx++]; + + ValueSpec& output_spec = test_case.outputs()[0]; + + const auto in_sizes = input_spec.get_tensor_sizes(); + for (auto d : in_sizes) { + if (d > kRefDimSizeLimit) { + throw std::invalid_argument("Dim exceeds reference compute limit"); + } + } + + const int64_t N = in_sizes[0]; + const int64_t C_in = in_sizes[1]; + const int64_t H_in = in_sizes[2]; + const int64_t W_in = in_sizes[3]; + const int32_t r = upscale_factor_spec.get_int_value(); + const int64_t C_out = C_in / (r * r); + const int64_t H_out = H_in * r; + const int64_t W_out = W_in * r; + + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zp = input_zp_spec.get_int_value(); + 
const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zp = output_zp_spec.get_int_value(); + const int32_t qmin = -128; + const int32_t qmax = 127; + + const auto& input_data = input_spec.get_float_data(); + auto& ref = output_spec.get_ref_float_data(); + ref.resize(N * C_out * H_out * W_out); + + for (int64_t n = 0; n < N; ++n) { + for (int64_t c_out = 0; c_out < C_out; ++c_out) { + for (int64_t oh = 0; oh < H_out; ++oh) { + for (int64_t ow = 0; ow < W_out; ++ow) { + const int64_t c_in = c_out * r * r + (oh % r) * r + (ow % r); + const int64_t ih = oh / r; + const int64_t iw = ow / r; + const int64_t in_idx = ((n * C_in + c_in) * H_in + ih) * W_in + iw; + const float fp_in = input_data[in_idx]; + + // Quantize with input qparams + float qf = std::round(fp_in / input_scale) + input_zp; + qf = std::max(qf, static_cast(qmin)); + qf = std::min(qf, static_cast(qmax)); + int32_t q_in = static_cast(qf); + + // Dequantize back to fp using the input qparams (this models the + // dequantize node in the chain) + float dq = (q_in - input_zp) * input_scale; + + // Requantize to int8 with output qparams + float rqf = std::round(dq / output_scale) + output_zp; + rqf = std::max(rqf, static_cast(qmin)); + rqf = std::min(rqf, static_cast(qmax)); + int32_t q_out = static_cast(rqf); + + // Final dequantize to fp for comparison + float fp_out = (q_out - output_zp) * output_scale; + + const int64_t out_idx = + ((n * C_out + c_out) * H_out + oh) * W_out + ow; + ref[out_idx] = fp_out; + } + } + } + } +} + +int main(int /*argc*/, char* /*argv*/[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q8TA PixelShuffle Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q8ta_pixel_shuffle_reference_impl; + + auto results = execute_test_cases( + generate_all_cases, + "Q8taPixelShuffle", + 3, // warmup 
+ 10, // benchmark + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index e3096f13b62..a5a0e2647a2 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -834,6 +834,29 @@ def get_upsample_bilinear2d_inputs(): return VkTestSuite(inputs_list) +@register_test_suite("aten.pixel_shuffle.default") +def get_pixel_shuffle_inputs(): + test_suite = VkTestSuite( + [ + # (input tensor shape (N, C*r*r, H, W), upscale_factor r) + ((1, 4, 2, 2), 2), + ((1, 9, 3, 3), 3), + ((1, 16, 2, 2), 4), + ((2, 4, 3, 5), 2), + ((1, 8, 4, 4), 2), + ((1, 12, 3, 4), 2), + ] + ) + test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] + test_suite.layouts = [ + "utils::kChannelsPacked", + "utils::kWidthPacked", + "utils::kHeightPacked", + ] + test_suite.dtypes = ["at::kFloat", "at::kHalf"] + return test_suite + + @register_test_suite(["aten.full.default", "aten.full_like.default"]) def get_full_inputs(): test_suite = VkTestSuite( diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index bcd240d8d12..fa448102b8e 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -403,3 +403,100 @@ def forward(self, x): 1, "Expected non-aligned OC linear to fuse into q8ta_linear_gemv", ) + + def test_fuse_quantized_pixel_shuffle(self): + """An un-decomposed pixel_shuffle wrapped in dequantize/quantize_per_tensor + ops should fuse into a single et_vk.q8ta_pixel_shuffle.default node, and + none of the original quant/dequant nodes should remain. + + The matcher relies on the partitioner's `ops_to_not_decompose()` hook + keeping `aten.pixel_shuffle.default` intact through edge lowering. We + replicate that behaviour here via `EdgeCompileConfig.preserve_ops` so + the test exercises the same graph shape that the partitioner produces + end-to-end. 
+ """ + + class PixelShuffleModule(torch.nn.Module): + def forward(self, x): + x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 0.1, 0, -128, 127, torch.int8 + ) + y = torch.nn.functional.pixel_shuffle(x_dq, 2) + return torch.ops.quantized_decomposed.quantize_per_tensor( + y, 0.05, 1, -128, 127, torch.int8 + ) + + # Use a non-square H/W and a W that is not a multiple of 4 so the + # geometry checks exercise the same shapes the model uses. + x = torch.randint(-128, 127, (1, 96, 16, 9), dtype=torch.int8) + program = torch.export.export(PixelShuffleModule(), (x,), strict=True) + edge_program = to_edge( + program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + preserve_ops=[torch.ops.aten.pixel_shuffle.default], + ), + ) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + result = fuse_pass.call(ep.graph_module) + + self.assertTrue(result.modified) + + gm = ep.graph_module + self.assertEqual(op_node_count(gm, "q8ta_pixel_shuffle.default"), 1) + self.assertEqual(op_node_count(gm, "view_copy.default"), 0) + self.assertEqual(op_node_count(gm, "permute_copy.default"), 0) + self.assertEqual(op_node_count(gm, "pixel_shuffle.default"), 0) + self.assertEqual(op_node_count(gm, "dequantize_per_tensor.default"), 0) + self.assertEqual(op_node_count(gm, "quantize_per_tensor.default"), 0) + + # Verify the fused op carries the correct args. 
+ fused_node = next( + n + for n in gm.graph.nodes + if get_target_canonical_name(n) == "q8ta_pixel_shuffle.default" + ) + # args = (input, input_scale, input_zp, inv_output_scale, output_zp, r) + self.assertEqual(fused_node.args[1], 0.1) + self.assertEqual(fused_node.args[2], 0) + # 1.0 / 0.05 == 20.0 + self.assertEqual(fused_node.args[3], 20.0) + self.assertEqual(fused_node.args[4], 1) + self.assertEqual(fused_node.args[5], 2) + + def test_quantized_pixel_shuffle_pattern_rejects_non_match(self): + """A `dq -> relu -> q` chain (no pixel_shuffle in between) must NOT be + fused. The new matcher only triggers when a single + `aten.pixel_shuffle.default` node sits between the dequant/quant pair. + """ + + class NonPixelShuffleModule(torch.nn.Module): + def forward(self, x): + x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 0.1, 0, -128, 127, torch.int8 + ) + y = torch.nn.functional.relu(x_dq) + return torch.ops.quantized_decomposed.quantize_per_tensor( + y, 0.1, 0, -128, 127, torch.int8 + ) + + x = torch.randint(-128, 127, (1, 96, 16, 9), dtype=torch.int8) + program = torch.export.export(NonPixelShuffleModule(), (x,), strict=True) + edge_program = to_edge( + program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + preserve_ops=[torch.ops.aten.pixel_shuffle.default], + ), + ) + + ep = edge_program._edge_programs["forward"] + fuse_pass = FusePatternsPass() + fuse_pass._exported_program = ep + fuse_pass.call(ep.graph_module) + + gm = ep.graph_module + self.assertEqual(op_node_count(gm, "q8ta_pixel_shuffle.default"), 0) diff --git a/backends/vulkan/test/test_vulkan_tensor_repr.py b/backends/vulkan/test/test_vulkan_tensor_repr.py index 64d7542b788..5a0fc664c17 100644 --- a/backends/vulkan/test/test_vulkan_tensor_repr.py +++ b/backends/vulkan/test/test_vulkan_tensor_repr.py @@ -649,7 +649,7 @@ def test_no_sync_primary_io_when_different_repsets(self): # -- Scalar args are skipped -- def test_scalar_arg_skipped(self): - """Non-tensor args 
should be treated as ALL_STORAGES_REPSET.""" + """Non-tensor args should be treated as ANY_STORAGE_INCL_PACKED_INT8.""" tensor_arg = _make_tensor_arg_node((1, 3, 8, 8)) # Second arg is a scalar (float) scalar_arg = 1.0 @@ -666,8 +666,8 @@ def test_scalar_arg_skipped(self): DEFAULT_TEXTURE_LIMITS, ) self.assertFalse(op_repsets.any_is_empty()) - # The scalar arg should get ALL_STORAGES_REPSET - # self.assertEqual(op_repsets.get_arg_repset(1), ALL_STORAGES_REPSET, f"""{op_repsets.get_arg_repset(1)}""") + # The scalar arg should get ANY_STORAGE_INCL_PACKED_INT8 + # self.assertEqual(op_repsets.get_arg_repset(1), ANY_STORAGE_INCL_PACKED_INT8, f"""{op_repsets.get_arg_repset(1)}""") # -- pick_representations -- diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index f93fec167eb..7febff260c6 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -1203,8 +1203,15 @@ def filter_invalid_reprs_for_node_list( # Special use RepSets NO_STORAGE = TensorRepSet(set(), set()) -ALL_STORAGES_REPSET = TensorRepSet( - universal_memory_layout_set, universal_memory_layout_set +# Buffer side admits both float and quantized (PACKED_INT8_*) layouts; texture side +# is float-only because the Vulkan backend has no quantized texture support +# (required_image_extents and the texture indexing helpers only know about the +# float layouts). Used as an intersection identity (e.g. common_arg_repset +# accumulator) and as a placeholder for non-tensor / not-yet-prepacked args, so +# narrowing the texture side is non-breaking for those uses while letting it act +# as a true universal set when intersected against quant-aware repsets. +ANY_STORAGE_INCL_PACKED_INT8 = TensorRepSet( + universal_memory_layout_set, all_memory_layouts ) @@ -1330,19 +1337,19 @@ def __init__( # noqa: C901 # Now, go through the arguments of the operator and create a filtered repset # for each based on the actual tensor value. 
args_repset_list = TensorRepSetList([]) - common_arg_repset = ALL_STORAGES_REPSET + common_arg_repset = ANY_STORAGE_INCL_PACKED_INT8 for i, arg_node in enumerate(op_node.args): arg_repset = inputs_repsets[i] - # Use ALL_STORAGES_REPSET for non-tensor nodes so they don't cause the op + # Use ANY_STORAGE_INCL_PACKED_INT8 for non-tensor nodes so they don't cause the op # repsets to appear empty if not is_tensor_arg_node(arg_node): - args_repset_list.append(ALL_STORAGES_REPSET) + args_repset_list.append(ANY_STORAGE_INCL_PACKED_INT8) # NO_STORAGE is used to denote that an input is either a non tensor arg or # a weight tensor that is not prepacked. Similar to the above, use - # ALL_STORAGES_REPSET in this case. + # ANY_STORAGE_INCL_PACKED_INT8 in this case. elif arg_repset.is_empty(): - args_repset_list.append(ALL_STORAGES_REPSET) + args_repset_list.append(ANY_STORAGE_INCL_PACKED_INT8) else: assert not arg_repset.is_empty() @@ -1355,7 +1362,7 @@ def __init__( # noqa: C901 # Repeat for output tensors. outs_repset_list = TensorRepSetList([]) - common_out_repset = ALL_STORAGES_REPSET + common_out_repset = ANY_STORAGE_INCL_PACKED_INT8 if num_tensors_in_node(op_node) == 1: common_out_repset = filter_invalid_reprs( op_node.meta["val"], outputs_repsets[0], texture_limits