
Commit 6730837

SS-JIA authored and committed
[ET-VK] Support different input layouts in q8ta_binary operator
Previously, the q8ta_binary operator required both inputs to use the same memory layout. This was enforced by using a single `in_layout` specialization constant for both input buffers. However, some models may have inputs with different layouts (e.g., 4W4C and 4C1W) that share the same packed dimension and block size, which should be compatible for binary operations.

This change introduces a separate `other_layout` specialization constant for the second input, allowing the shader to correctly load from input_b using its actual layout while input_a continues to use `in_layout`. The C++ side now passes both layout hashes as separate specialization constants to the shader.

Differential Revision: [D93768638](https://our.internmc.facebook.com/intern/diff/D93768638/)
ghstack-source-id: 342806076
Pull Request resolved: #17563
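To illustrate why the second input needs its own layout constant, here is a minimal C++ sketch. It does not use the real ET-VK layouts: `ToyLayout`, `block_offset`, `pack`, and `toy_add` are hypothetical names, and the two toy layouts (row-major and column-major block order) only stand in for the idea that two layouts can place the same logical block at different buffer offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy layouts. These are NOT the real 4W4C / 4C1W definitions;
// they only differ in the order in which blocks are stored.
enum class ToyLayout { RowMajorBlocks, ColMajorBlocks };

// Map a logical block coordinate (row, col) to a linear buffer offset.
std::size_t block_offset(ToyLayout layout, std::size_t row, std::size_t col,
                         std::size_t rows, std::size_t cols) {
  return layout == ToyLayout::RowMajorBlocks ? row * cols + col
                                             : col * rows + row;
}

// Store row-major logical values into a buffer using the given layout.
std::vector<int> pack(ToyLayout layout, const std::vector<int>& logical,
                      std::size_t rows, std::size_t cols) {
  assert(logical.size() == rows * cols);
  std::vector<int> buf(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      buf[block_offset(layout, r, c, rows, cols)] = logical[r * cols + c];
    }
  }
  return buf;
}

// Element-wise add that loads each input with the layout passed for it,
// mirroring the in_layout / other_layout split. Output is row-major.
std::vector<int> toy_add(const std::vector<int>& a, ToyLayout a_layout,
                         const std::vector<int>& b, ToyLayout b_layout,
                         std::size_t rows, std::size_t cols) {
  std::vector<int> out(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      out[r * cols + c] = a[block_offset(a_layout, r, c, rows, cols)] +
                          b[block_offset(b_layout, r, c, rows, cols)];
    }
  }
  return out;
}
```

In this sketch, packing `a` as `RowMajorBlocks` and `b` as `ColMajorBlocks` and then calling `toy_add` with each input's own layout gives the expected element-wise sum. Passing `RowMajorBlocks` for both inputs, which is analogous to reusing `in_layout` for input_b, pairs up the wrong elements.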
1 parent 8a10718 commit 6730837

2 files changed

Lines changed: 4 additions & 1 deletion


backends/vulkan/runtime/graph/ops/glsl/q8ta_binary.glsl

Lines changed: 2 additions & 1 deletion
@@ -46,6 +46,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 ${layout_declare_spec_const(C, "int", "out_layout", "CONTIG_LAYOUT_INT")}
 ${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
+${layout_declare_spec_const(C, "int", "other_layout", "CONTIG_LAYOUT_INT")}
 ${layout_declare_spec_const(C, "int", "block_config", "0")}
 
 // Generate loading functions for input buffers
@@ -71,7 +72,7 @@ void main() {
   ivec4 in_block_a = load_int8x4_block_from_t_in_a(
       in_a_meta, tidx, in_layout, block_outer_dim);
   ivec4 in_block_b = load_int8x4_block_from_t_in_b(
-      in_b_meta, tidx, in_layout, block_outer_dim);
+      in_b_meta, tidx, other_layout, block_outer_dim);
 
   ivec4 out_block;
 
backends/vulkan/runtime/graph/ops/impl/Q8taBinary.cpp

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@ void add_q8ta_binary_node(
 
   VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim);
   VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim);
+
   VK_CHECK_COND(
       input_a_info.packed_dim_block_size == output_info.packed_dim_block_size);
   VK_CHECK_COND(
@@ -105,6 +106,7 @@ void add_q8ta_binary_node(
       // Specialization Constants
       {graph.hashed_layout_of(packed_int8_output),
        graph.hashed_layout_of(packed_int8_input_a),
+       graph.hashed_layout_of(packed_int8_input_b),
        block_config.as_packed_int()},
       // Resize args
       {block_config_ref},
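The new hash is inserted between the input_a hash and `block_config`, the same position at which `other_layout` is declared in the shader. The sketch below assumes, as that parallel placement suggests, that host values are matched to shader constants by position. `ToyInputs`, `host_spec_consts`, and `value_seen_by_shader` are hypothetical names used only for illustration and are not part of the ET-VK API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string_view>
#include <vector>

// Shader-side declaration order after this commit, taken from the diff.
constexpr std::array<std::string_view, 4> kShaderSpecConsts = {
    "out_layout", "in_layout", "other_layout", "block_config"};

// Hypothetical stand-ins for the values produced by
// graph.hashed_layout_of(...) and block_config.as_packed_int().
struct ToyInputs {
  std::int32_t out_hash;
  std::int32_t in_a_hash;
  std::int32_t in_b_hash;
  std::int32_t block_config;
};

// Host-side list in the same order as the "Specialization Constants"
// initializer in the diff.
std::vector<std::int32_t> host_spec_consts(const ToyInputs& v) {
  std::vector<std::int32_t> out = {
      v.out_hash, v.in_a_hash, v.in_b_hash, v.block_config};
  assert(out.size() == kShaderSpecConsts.size());
  return out;
}

// Assumption: the i-th host value feeds the i-th declared constant. If the
// host list is too short, the constant keeps its fallback (default) value.
std::int32_t value_seen_by_shader(std::string_view name,
                                  const std::vector<std::int32_t>& host_values,
                                  std::int32_t fallback) {
  for (std::size_t i = 0; i < kShaderSpecConsts.size(); ++i) {
    if (kShaderSpecConsts[i] == name) {
      return i < host_values.size() ? host_values[i] : fallback;
    }
  }
  return fallback;
}
```

Under that positional assumption, a three-element host list without the input_b hash would hand the packed block config to `other_layout` and leave `block_config` at its default, which is why the shader and C++ changes have to land together.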
