Skip to content

Commit b8ba505

Browse files
author
ssjia
committed
[ET-VK][conv1d] Implement height-packed pointwise conv1d operator
Implement a new conv1d pointwise (kernel_size=1) operator using height-packed layout where channels are the packed dimension (WHCN dim 1). This enables dot-product reduction over input channels: each vec4 load gives 4 consecutive channel values, yielding 4 MACs per dot() instruction. Uses tiled computation with the FP tile infrastructure from linear/matmul (FPInputTile, FPWeightTile, FPOutTile, fp_accumulate_with_fp_weight) and 4OC×4IC blocked weight packing via pack_fp_linear_weight.glsl for cache-friendly texture2d weight reads. Adaptive tile_m selection (4/2/1 rows) based on GPU occupancy. Thread mapping: X=OC4 tiles, Y=L tiles, Z=batch. Each thread computes TILE_M×TILE_N4×4 output elements. Inner loop loads input tiles and packed weight tiles, then calls fp_accumulate_with_fp_weight for tiled FMA. Supports both buffer and texture3d storage for input/output, texture2d or buffer for packed weights, fp32/fp16, and optional bias. Registered as et_vk.conv1d_pw.default (standalone custom op for testing/benchmarking). Performance on Adreno 750 (S24): - [1,256,1024]x[512,256,1] texture f16: 908 GFLOP/s - [1,512,2048]x[256,512,1] texture f16: 865 GFLOP/s - [1,128,4096]x[128,128,1] texture f16: 781 GFLOP/s - [1,256,1024]x[512,256,1] buffer f16: 491 GFLOP/s Differential Revision: [D97344092](https://our.internmc.facebook.com/intern/diff/D97344092/) [ghstack-poisoned]
1 parent bf2243a commit b8ba505

6 files changed

Lines changed: 771 additions & 0 deletions

File tree

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
#define T ${texel_load_component_type(DTYPE, STORAGE)}
14+
15+
$if STORAGE == "buffer":
16+
#define OUTPUT_BUFFER
17+
#define INPUT_BUFFER
18+
$if WEIGHT_STORAGE == "buffer":
19+
#define WEIGHT_BUFFER
20+
$if HAS_BIAS:
21+
#define HAS_BIAS
22+
$if STORAGE == "buffer" and HAS_BIAS:
23+
#define BIAS_BUFFER
24+
25+
#define TILE_M4 ${TILE_M4}
26+
#define TILE_K4 ${TILE_K4}
27+
#define TILE_N4 ${TILE_N4}
28+
29+
#define TILE_M ${TILE_M}
30+
#define TILE_K ${TILE_K4 * 4}
31+
#define TILE_N ${TILE_N4 * 4}
32+
33+
${define_required_extensions(STORAGE, DTYPE)}
34+
$if WEIGHT_STORAGE != STORAGE:
35+
${define_required_extensions(WEIGHT_STORAGE, DTYPE)}
36+
37+
layout(std430) buffer;
38+
39+
#include "common.glslh"
40+
41+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
42+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
43+
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, WEIGHT_STORAGE, is_scalar_array=False)}
44+
$if HAS_BIAS:
45+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)}
46+
47+
// in_sizes: {L, C_in, N, 1} in WHCN order
48+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
49+
// out_sizes: {L, C_out, N, 1} in WHCN order
50+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
51+
$if HAS_BIAS:
52+
${layout_declare_ubo(B, "ivec4", "bias_sizes")}
53+
54+
$if HAS_BIAS:
55+
layout(push_constant) uniform restrict Block {
56+
int weight_B;
57+
float alpha;
58+
float beta;
59+
};
60+
$else:
61+
layout(push_constant) uniform restrict Block {
62+
int weight_B;
63+
};
64+
65+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
66+
67+
#include "linear_fp_input_tile.glslh"
68+
#include "linear_fp_weight_tile.glslh"
69+
#include "linear_fp_output_tile.glslh"
70+
#include "linear_fp_packed_weight_tile_load.glslh"
71+
#include "linear_fp_output_tile_fp_compute.glslh"
72+
73+
// Conv1d pointwise is matrix multiplication with swapped texture coordinates.
74+
// Linear: input ivec3(k4, m, b), output ivec3(n4, m, b) [width-packed]
75+
// Conv1d: input ivec3(m, k4, b), output ivec3(m, n4, b) [height-packed]
76+
// Buffer indexing is identical: (b * M + m) * K4 + k4
77+
78+
VEC4_T load_input_x4(
79+
const int k4,
80+
const int m,
81+
const int b,
82+
const int K4,
83+
const int M) {
84+
#ifdef INPUT_BUFFER
85+
return t_in[(b * M + m) * K4 + k4];
86+
#else
87+
return texelFetch(t_in, ivec3(m, k4, b), 0);
88+
#endif
89+
}
90+
91+
void load_input_tile_with_checks(
92+
out FPInputTile tile,
93+
const int k4_start,
94+
const int m_start,
95+
const int b,
96+
const int K4,
97+
const int M) {
98+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
99+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
100+
if (k4_start + k4 < K4 && m_start + m < M) {
101+
tile.data[m][k4] =
102+
load_input_x4(k4_start + k4, m_start + m, b, K4, M);
103+
} else {
104+
tile.data[m][k4] = VEC4_T(0.0);
105+
}
106+
}
107+
}
108+
}
109+
110+
void store_output_x4(
111+
const VEC4_T texel,
112+
const int n4,
113+
const int m,
114+
const int b,
115+
const int N4,
116+
const int M) {
117+
#ifdef OUTPUT_BUFFER
118+
t_out[(b * M + m) * N4 + n4] = texel;
119+
#else
120+
imageStore(t_out, ivec3(m, n4, b), texel);
121+
#endif
122+
}
123+
124+
void store_output_tile_with_checks(
125+
const FPOutTile out_tile,
126+
const int n4_start,
127+
const int m_start,
128+
const int b,
129+
const int N4,
130+
const int M) {
131+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
132+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
133+
if (m_start + m < M && n4_start + n4 < N4) {
134+
store_output_x4(
135+
out_tile.data[m][n4], n4_start + n4, m_start + m, b, N4, M);
136+
}
137+
}
138+
}
139+
}
140+
141+
void main() {
142+
// Thread mapping: X=OC4 (N4), Y=L/tile_m (M tiles), Z=batch
143+
const int tile_idx_n = int(gl_GlobalInvocationID.x);
144+
const int tile_idx_m = int(gl_GlobalInvocationID.y);
145+
146+
const int n4_start = tile_idx_n * TILE_N4;
147+
const int m_start = tile_idx_m * TILE_M;
148+
149+
// in_sizes: {L, C_in, N, 1} in WHCN
150+
const int K = in_sizes.y; // C_in
151+
const int M = in_sizes.x; // L
152+
const int K4 = div_up_4(K);
153+
// out_sizes: {L, C_out, N, 1} in WHCN
154+
const int N_out = out_sizes.y; // C_out
155+
const int N4 = div_up_4(N_out);
156+
157+
if (n4_start >= N4 || m_start >= M) {
158+
return;
159+
}
160+
161+
FPOutTile out_tile;
162+
initialize(out_tile);
163+
164+
FPInputTile in_tile;
165+
FPWeightTile w_tile;
166+
167+
const int b = int(gl_GlobalInvocationID.z);
168+
169+
for (int k4 = 0; k4 < K4; k4++) {
170+
load_input_tile_with_checks(in_tile, k4, m_start, b, K4, M);
171+
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
172+
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
173+
}
174+
175+
#ifdef HAS_BIAS
176+
// Load bias (per output channel, width-packed) and apply
177+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
178+
VEC4_T bias_val = VEC4_T(0.0);
179+
if (n4_start + n4 < N4) {
180+
#ifdef BIAS_BUFFER
181+
bias_val = t_bias[n4_start + n4];
182+
#else
183+
bias_val = texelFetch(t_bias, ivec3(n4_start + n4, 0, 0), 0);
184+
#endif
185+
}
186+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
187+
out_tile.data[m][n4] =
188+
VEC4_T(alpha) * out_tile.data[m][n4] + VEC4_T(beta) * bias_val;
189+
}
190+
}
191+
#endif
192+
193+
store_output_tile_with_checks(out_tile, n4_start, m_start, b, N4, M);
194+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv1d_pw:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
WEIGHT_STORAGE: texture2d
12+
HAS_BIAS: false
13+
TILE_M4: 1
14+
TILE_K4: 1
15+
TILE_N4: 1
16+
TILE_M: 4
17+
generate_variant_forall:
18+
combination:
19+
parameter_names: [STORAGE, WEIGHT_STORAGE]
20+
combos:
21+
- parameter_values: [texture3d, texture2d]
22+
- parameter_values: [texture3d, buffer]
23+
- parameter_values: [buffer, texture2d]
24+
- parameter_values: [buffer, buffer]
25+
DTYPE:
26+
- VALUE: float
27+
- VALUE: half
28+
shader_variants:
29+
- NAME: conv1d_pw
30+
- NAME: conv1d_pw_tile_row_2
31+
TILE_M: 2
32+
- NAME: conv1d_pw_tile_row_1
33+
TILE_M: 1
34+
- NAME: conv1d_pw_bias
35+
HAS_BIAS: true
36+
- NAME: conv1d_pw_bias_tile_row_2
37+
HAS_BIAS: true
38+
TILE_M: 2
39+
- NAME: conv1d_pw_bias_tile_row_1
40+
HAS_BIAS: true
41+
TILE_M: 1

0 commit comments

Comments
 (0)