
Commit 1c319ca

Author: ssjia
[ET-VK][q8_ops] Add int8x4_buffer_to_nchw shader and refactor Int8x4Staging
Renames Q8taStaging.cpp/h to Int8x4Staging.cpp/h and expands it to cover the full staging lifecycle for kInt8x4 buffer tensors.

**Rename and split of the old prepack function:** The old `add_staging_to_int8x4_buffer_node` (which used a static dispatch node for prepacking TensorRef data into a packed int8x4 buffer) is renamed to `add_prepack_int8x4_buffer_node` to clarify its role. Two new runtime staging functions are added alongside it:

- `add_staging_to_int8x4_buffer_node`: reads NCHW data from a staging buffer into a kInt8x4 buffer tensor at execute time, using a `DynamicDispatchNode` wrapping the existing `nchw_to_int8x4_buffer` shader.
- `add_int8x4_buffer_to_staging_node`: writes packed int8x4 data back from a kInt8x4 buffer tensor to a contiguous NCHW staging buffer at execute time, using a new `int8x4_buffer_to_nchw` shader.

**New shader (int8x4_buffer_to_nchw.glsl):** Implements the reverse of `nchw_to_int8x4_buffer`, with one thread per output int32 in the NCHW staging buffer. Each thread decodes 4 NCHW-ordered element indices, looks up each element's position in the packed int8x4 buffer via `tensor4d_idx_to_buf_idx`, extracts the packed byte, and assembles the 4 bytes into a single output int32. Works for any GPUMemoryLayout.

**Staging.cpp dispatch:** `add_staging_to_tensor_node` and `add_tensor_to_staging_node` now both dispatch to the int8x4-specific functions when the tensor dtype is kInt8x4. `prepack_op` is updated to call `add_prepack_int8x4_buffer_node`.

**TestQ8taBinary.cpp** is updated to include Int8x4Staging.h and call `add_prepack_int8x4_buffer_node`.

Differential Revision: [D94364640](https://our.internmc.facebook.com/intern/diff/D94364640/)

[ghstack-poisoned]
1 parent 588f2a6 commit 1c319ca

8 files changed: 333 additions & 54 deletions
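For reference, a minimal CPU sketch of the byte-packing convention the commit message describes, assuming the ordering implied by the shaders (the int8 element at linear buffer index `i` lives in int32 slot `i / 4`, at byte position `i % 4`, lowest byte first); the helper names are illustrative, not part of this commit:

```cpp
#include <cstdint>
#include <vector>

// Read the int8 element at linear index i out of a packed int8x4 buffer.
int8_t unpack_int8(const std::vector<int32_t>& packed, size_t i) {
  return static_cast<int8_t>((packed[i / 4] >> ((i % 4) * 8)) & 0xFF);
}

// Write the int8 element at linear index i into a packed int8x4 buffer.
void pack_int8(std::vector<int32_t>& packed, size_t i, int8_t val) {
  const uint32_t shift = static_cast<uint32_t>((i % 4) * 8);
  uint32_t word = static_cast<uint32_t>(packed[i / 4]);
  word &= ~(0xFFu << shift); // clear the target byte
  word |= static_cast<uint32_t>(static_cast<uint8_t>(val)) << shift;
  packed[i / 4] = static_cast<int32_t>(word);
}
```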


backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.glsl (new file)

Lines changed: 75 additions & 0 deletions

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

${define_active_storage_type("buffer")}

layout(std430) buffer;

#include "indexing.glslh"

// Output staging buffer: raw int8 data interpreted as int32 for device compat
${layout_declare_tensor(B, "w", "nchw_out", "int", "buffer")}
// Input buffer: packed int8x4 values (each int32 contains 4 packed int8)
${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")}

// Metadata for input tensor
${layout_declare_ubo(B, "BufferMetadata", "inp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")}

void main() {
  // One thread per output int32 in the NCHW staging buffer.
  // Each output int32 holds 4 consecutive NCHW bytes.
  const uint out_int32_idx = gl_GlobalInvocationID.x;

  const uint W = inp.sizes[0][0];
  const uint H = inp.sizes[0][1];
  const uint C = inp.sizes[0][2];
  const uint N = inp.sizes[0][3];
  const uint total_numel = W * H * C * N;
  const uint num_out_int32s = (total_numel + 3u) / 4u;

  if (out_int32_idx >= num_out_int32s) {
    return;
  }

  int output_int32 = 0;
  [[unroll]] for (int j = 0; j < 4; ++j) {
    const uint nchw_idx = out_int32_idx * 4u + uint(j);
    if (nchw_idx >= total_numel) {
      break;
    }

    // Convert NCHW linear index to tensor4D (WHCN) coordinates.
    const uint w = nchw_idx % W;
    const uint h = (nchw_idx / W) % H;
    const uint c = (nchw_idx / (W * H)) % C;
    const uint n = nchw_idx / (W * H * C);

    TensorIndex4D tidx;
    tidx.data = ivec4(int(w), int(h), int(c), int(n));

    // tensor4d_idx_to_buf_idx returns a linear element index where
    // element_index / 4 is the int32 slot and element_index % 4 is the byte
    // position within that int32. This matches the packing order used by
    // nchw_to_int8x4_buffer when writing to the int8x4 buffer.
    const int elem_buf_idx = tensor4d_idx_to_buf_idx(inp, tidx, inp_layout);
    const int int8_val =
        (t_inp[elem_buf_idx / 4] >> ((elem_buf_idx % 4) * 8)) & 0xFF;

    output_int32 |= (int8_val << (j * 8));
  }

  nchw_out[out_int32_idx] = output_int32;
}
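The shader body maps directly onto a scalar reference implementation. A hedged CPU sketch of the same loop follows; the `buf_idx_of` callback stands in for `tensor4d_idx_to_buf_idx` (defined in indexing.glslh, not shown in this diff), so the layout mapping here is an assumption supplied by the caller:

```cpp
#include <cstdint>
#include <functional>
#include <vector>

// CPU reference for int8x4_buffer_to_nchw: one iteration of the outer loop
// corresponds to one shader invocation (one output int32 of NCHW data).
std::vector<int32_t> int8x4_buffer_to_nchw_ref(
    const std::vector<int32_t>& packed,
    uint32_t W, uint32_t H, uint32_t C, uint32_t N,
    const std::function<uint32_t(uint32_t, uint32_t, uint32_t, uint32_t)>&
        buf_idx_of) {
  const uint32_t numel = W * H * C * N;
  std::vector<int32_t> nchw_out((numel + 3u) / 4u, 0);
  for (uint32_t out = 0; out < nchw_out.size(); ++out) {
    int32_t word = 0;
    for (uint32_t j = 0; j < 4u; ++j) {
      const uint32_t nchw_idx = out * 4u + j;
      if (nchw_idx >= numel) break; // tail of a partially filled int32
      // Decode the NCHW linear index into WHCN coordinates.
      const uint32_t w = nchw_idx % W;
      const uint32_t h = (nchw_idx / W) % H;
      const uint32_t c = (nchw_idx / (W * H)) % C;
      const uint32_t n = nchw_idx / (W * H * C);
      // Look up the element's linear index in the packed buffer, then
      // extract its byte (int32 slot = elem / 4, byte position = elem % 4).
      const uint32_t elem = buf_idx_of(w, h, c, n);
      const uint32_t byte_val =
          (static_cast<uint32_t>(packed[elem / 4]) >> ((elem % 4) * 8)) & 0xFFu;
      word |= static_cast<int32_t>(byte_val << (j * 8));
    }
    nchw_out[out] = word;
  }
  return nchw_out;
}
```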
backends/vulkan/runtime/graph/ops/glsl/int8x4_buffer_to_nchw.yaml (new file)

Lines changed: 11 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

int8x4_buffer_to_nchw:
  parameter_names_with_default_values:
    DTYPE: int
  shader_variants:
    - NAME: int8x4_buffer_to_nchw
backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.cpp (new file; replaces Q8taStaging.cpp)

Lines changed: 140 additions & 0 deletions

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_prepack_int8x4_buffer_node(
    ComputeGraph& graph,
    const ValueRef tensor_data,
    const ValueRef tensor) {
  VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
  // TODO(ssjia): Update shaders to handle high-dim tensors
  VK_CHECK_COND(graph.dim_of(tensor) <= 4);

  std::string kernel_name = "nchw_to_int8x4_buffer";

  vkapi::ParamsBindList param_buffers;
  param_buffers.append(graph.buffer_meta_ubo(tensor));

  // One thread per texel (each texel = one int32 = 4 packed int8).
  // Use padded_numel to account for dimension padding in packed int8 layouts
  // (e.g., kPackedInt8_4C with C=3 pads to C=4).
  uint32_t num_texels =
      utils::safe_downcast<uint32_t>(graph.padded_numel_of(tensor) / 4);
  utils::uvec3 global_wg_size = {num_texels, 1, 1};
  utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.prepack_nodes().emplace_back(new PrepackNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      local_wg_size,
      // Input and Output
      tensor_data,
      tensor,
      // Parameter Buffers
      param_buffers,
      // Specialization Constants
      {graph.hashed_layout_of(tensor)}));
}

static utils::uvec3 staging_to_int8x4_buffer_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;
  const ValueRef out_tensor = args.at(0).refs.at(0);
  const uint32_t num_texels =
      utils::safe_downcast<uint32_t>(graph->padded_numel_of(out_tensor) / 4);
  return {num_texels, 1, 1};
}

void add_staging_to_int8x4_buffer_node(
    ComputeGraph& graph,
    const ValueRef in_staging,
    const ValueRef tensor) {
  VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
  // TODO(ssjia): Update shaders to handle high-dim tensors
  VK_CHECK_COND(graph.dim_of(tensor) <= 4);

  vkapi::ParamsBindList param_buffers;
  param_buffers.append(graph.buffer_meta_ubo(tensor));

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR("nchw_to_int8x4_buffer"),
      staging_to_int8x4_buffer_global_wg_size,
      default_pick_local_wg_size,
      // Input and Output
      {{tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}},
      // Parameter Buffers
      param_buffers,
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(tensor)},
      // Resize Args
      {},
      // Resizing Logic
      nullptr));
}

static utils::uvec3 int8x4_buffer_to_staging_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;
  const ValueRef in_tensor = args.at(1).refs.at(0);
  // One thread per output int32 in the NCHW staging buffer.
  const int32_t numel = graph->numel_of(in_tensor);
  const uint32_t num_out_int32s =
      utils::safe_downcast<uint32_t>((numel + 3) / 4);
  return {num_out_int32s, 1, 1};
}

void add_int8x4_buffer_to_staging_node(
    ComputeGraph& graph,
    const ValueRef tensor,
    const ValueRef staging_data) {
  VK_CHECK_COND(graph.dtype_of(tensor) == vkapi::kInt8x4);
  // TODO(ssjia): Update shaders to handle high-dim tensors
  VK_CHECK_COND(graph.dim_of(tensor) <= 4);

  vkapi::ParamsBindList param_buffers;
  param_buffers.append(graph.buffer_meta_ubo(tensor));

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR("int8x4_buffer_to_nchw"),
      int8x4_buffer_to_staging_global_wg_size,
      default_pick_local_wg_size,
      // Input and Output
      {{staging_data, vkapi::kWrite}, {tensor, vkapi::kRead}},
      // Parameter Buffers
      param_buffers,
      // Push Constants
      {},
      // Specialization Constants
      {graph.hashed_layout_of(tensor)},
      // Resize Args
      {},
      // Resizing Logic
      nullptr));
}

} // namespace vkcompute
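Note the asymmetry in how the two directions are dispatched: the packing direction (`add_prepack_int8x4_buffer_node` and `add_staging_to_int8x4_buffer_node`) launches `padded_numel / 4` threads, one per packed int32 including layout padding, while the unpacking direction (`add_int8x4_buffer_to_staging_node`) launches `(numel + 3) / 4`, one per int32 of unpadded NCHW output. A worked example, assuming the kPackedInt8_4C padding mentioned in the comments (the concrete sizes are illustrative, not from the commit):

```cpp
#include <cstdio>

int main() {
  // Hypothetical 1x3x4x4 tensor (N=1, C=3, H=4, W=4) whose channel dim is
  // padded from 3 to 4 under a channels-packed int8 layout.
  const unsigned numel = 1 * 3 * 4 * 4;        // 48 logical elements
  const unsigned padded_numel = 1 * 4 * 4 * 4; // 64 elements after padding
  printf("pack threads:   %u\n", padded_numel / 4); // 16: one per packed int32
  printf("unpack threads: %u\n", (numel + 3) / 4);  // 12: one per NCHW int32
  return 0;
}
```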

backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h renamed to backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h

Lines changed: 11 additions & 1 deletion

@@ -12,9 +12,19 @@
 
 namespace vkcompute {
 
-void add_staging_to_int8x4_buffer_node(
+void add_prepack_int8x4_buffer_node(
     ComputeGraph& graph,
     const ValueRef tensor_data,
     const ValueRef tensor);
 
+void add_staging_to_int8x4_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef in_staging,
+    const ValueRef tensor);
+
+void add_int8x4_buffer_to_staging_node(
+    ComputeGraph& graph,
+    const ValueRef tensor,
+    const ValueRef staging_data);
+
 } // namespace vkcompute
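Taken together, the header now exposes one prepack-time entry point and two execute-time entry points. A hedged wiring sketch follows; the helper function and its ValueRef parameters are hypothetical scaffolding, and only the three calls mirror the declarations above:

```cpp
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>

namespace vkcompute {

// Hypothetical helper: assumes `tensor` is a kInt8x4 buffer tensor,
// `weight_data` a TensorRef, and `staging_in`/`staging_out` staging buffers.
// A real graph would use either the prepack path or the staging path for a
// given tensor; both appear here only to demonstrate each entry point.
void wire_int8x4_staging(
    ComputeGraph& graph,
    ValueRef weight_data,
    ValueRef staging_in,
    ValueRef staging_out,
    ValueRef tensor) {
  // Load time: static prepack of TensorRef data into the packed buffer.
  add_prepack_int8x4_buffer_node(graph, weight_data, tensor);
  // Execute time: NCHW staging -> packed int8x4 tensor.
  add_staging_to_int8x4_buffer_node(graph, staging_in, tensor);
  // Execute time: packed int8x4 tensor -> contiguous NCHW staging.
  add_int8x4_buffer_to_staging_node(graph, tensor, staging_out);
}

} // namespace vkcompute
```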

backends/vulkan/runtime/graph/ops/impl/Q8taStaging.cpp

Lines changed: 0 additions & 49 deletions
This file was deleted.

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 10 additions & 2 deletions

@@ -12,7 +12,7 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
-#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
 
@@ -27,6 +27,10 @@ void add_staging_to_tensor_node(
     const ValueRef out_tensor) {
   VK_CHECK_COND(graph.val_is_staging(in_staging));
 
+  if (graph.dtype_of(out_tensor) == vkapi::kInt8x4) {
+    return add_staging_to_int8x4_buffer_node(graph, in_staging, out_tensor);
+  }
+
   vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(
       graph,
       out_tensor,
@@ -104,6 +108,10 @@ void add_tensor_to_staging_node(
     const ValueRef out_staging) {
   VK_CHECK_COND(graph.val_is_staging(out_staging));
 
+  if (graph.dtype_of(in_tensor) == vkapi::kInt8x4) {
+    return add_int8x4_buffer_to_staging_node(graph, in_tensor, out_staging);
+  }
+
   vkapi::ShaderInfo shader = get_tensor_to_nchw_shader(
       graph,
       in_tensor,
@@ -329,7 +337,7 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
 
 void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   if (graph.dtype_of(args[1]) == vkapi::kInt8x4) {
-    return add_staging_to_int8x4_buffer_node(graph, args[0], args[1]);
+    return add_prepack_int8x4_buffer_node(graph, args[0], args[1]);
   }
   return add_prepack_standard_node(graph, args[0], args[1]);
 }

backends/vulkan/test/custom_ops/impl/TestQ8taBinary.cpp

Lines changed: 2 additions & 2 deletions

@@ -8,9 +8,9 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Int8x4Staging.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taBinary.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h>
-#include <executorch/backends/vulkan/runtime/graph/ops/impl/Q8taStaging.h>
 
 namespace vkcompute {
 
@@ -62,7 +62,7 @@ void q8ta_add_test(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   if (input_b_is_int8) {
     // Input B is a pre-quantized int8 TensorRef; prepack directly into packed
     // int8x4 format
-    add_staging_to_int8x4_buffer_node(graph, input_b, packed_int8_input_b);
+    add_prepack_int8x4_buffer_node(graph, input_b, packed_int8_input_b);
   } else {
     // Input B is a float tensor; quantize at runtime
     add_q8ta_quantize_node(
