Skip to content

Commit af5aaa5

Browse files
author
ssjia
committed
Update on "[ET-VK][conv1d] Implement height-packed depthwise conv1d operator"
Implement a depthwise conv1d operator using a height-packed layout, where channels are the packed dimension (WHCN dim 1). Depthwise conv applies a separate filter to each channel independently (groups=C), so 4 channels can be processed in parallel using element-wise vec4 FMA over kernel positions.

Thread mapping: X=C/4, Y=L_out, Z=N. Each thread computes one output texel (4 channels at one spatial position). The inner loop iterates over kernel positions K with bounds-checked input access for padding. Weight [C,1,K] is prepacked as channels-packed so that each vec4 load gives 4 channels' weights at one kernel position.

Supports both buffer and texture3d storage, fp32/fp16, optional bias, and arbitrary stride/padding/dilation. Registered as et_vk.conv1d_dw.default (standalone custom op).

Performance on Adreno 750 (S24):
- [1,128,4096] K=31 buffer f16: 231 GFLOP/s
- [1,128,4096] K=31 buffer f32: 155 GFLOP/s
- [1,512,2048] K=5 buffer f32: 66 GFLOP/s

Differential Revision: [D97344091](https://our.internmc.facebook.com/intern/diff/D97344091/)

[ghstack-poisoned]
2 parents 0934d48 + 3eca3e2 commit af5aaa5

4 files changed

Lines changed: 86 additions & 14 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ layout(push_constant) uniform restrict Block {
3939
int stride;
4040
int padding;
4141
int dilation;
42+
float output_min;
43+
float output_max;
4244
};
4345

4446
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -86,6 +88,8 @@ void main() {
8688
#endif
8789
#endif
8890

91+
sum = clamp(sum, VEC4_T(output_min), VEC4_T(output_max));
92+
8993
#ifdef BUFFER
9094
t_out[(n * L_out + l_out) * C4 + c4] = sum;
9195
#else

backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,14 @@ $if HAS_BIAS:
5656
int weight_B;
5757
float alpha;
5858
float beta;
59+
float output_min;
60+
float output_max;
5961
};
6062
$else:
6163
layout(push_constant) uniform restrict Block {
6264
int weight_B;
65+
float output_min;
66+
float output_max;
6367
};
6468

6569
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -190,5 +194,13 @@ void main() {
190194
}
191195
#endif
192196

197+
// Apply activation clamp
198+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
199+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
200+
out_tile.data[m][n4] =
201+
clamp(out_tile.data[m][n4], VEC4_T(output_min), VEC4_T(output_max));
202+
}
203+
}
204+
193205
store_output_tile_with_checks(out_tile, n4_start, m_start, b, N4, M);
194206
}

backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1818

19+
#include <limits>
20+
1921
namespace vkcompute {
2022

2123
void resize_conv1d_dw_node(
@@ -48,6 +50,11 @@ struct Conv1dDWParams final {
4850
int32_t dilation;
4951
};
5052

53+
struct Conv1dDWClampParams final {
54+
float output_min;
55+
float output_max;
56+
};
57+
5158
utils::uvec3 pick_conv1d_dw_global_wg_size(
5259
ComputeGraph* graph,
5360
const vkapi::ShaderInfo& shader,
@@ -74,7 +81,9 @@ void add_conv1d_dw_node(
7481
const ValueRef stride_ref,
7582
const ValueRef padding_ref,
7683
const ValueRef dilation_ref,
77-
const ValueRef out) {
84+
const ValueRef out,
85+
const float output_min = std::numeric_limits<float>::lowest(),
86+
const float output_max = std::numeric_limits<float>::max()) {
7887
VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim);
7988
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim);
8089

@@ -103,6 +112,11 @@ void add_conv1d_dw_node(
103112
utils::safe_downcast<int32_t>(dilation_val),
104113
};
105114

115+
Conv1dDWClampParams clamp_params{
116+
output_min,
117+
output_max,
118+
};
119+
106120
std::string kernel_name = has_bias ? "conv1d_dw_bias" : "conv1d_dw";
107121
kernel_name.reserve(kShaderNameReserve);
108122
add_storage_type_suffix(kernel_name, storage_type);
@@ -123,7 +137,8 @@ void add_conv1d_dw_node(
123137
// Shader params buffers
124138
{graph.sizes_ubo(in), graph.sizes_ubo(out)},
125139
// Push Constants
126-
{PushConstantDataInfo(&params, sizeof(Conv1dDWParams))},
140+
{PushConstantDataInfo(&params, sizeof(Conv1dDWParams)),
141+
PushConstantDataInfo(&clamp_params, sizeof(Conv1dDWClampParams))},
127142
// Specialization Constants
128143
{},
129144
// Resize Args
@@ -132,17 +147,38 @@ void add_conv1d_dw_node(
132147
resize_conv1d_dw_node));
133148
}
134149

150+
// Args: in, weight, bias, stride, padding, dilation, groups,
151+
// output_min, output_max, out
152+
// output_min and output_max may be kDummyValueRef (no clamp).
135153
void conv1d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
136-
// args: in, weight, bias, stride, padding, dilation, groups, out
137154
ValueRef in = args[0];
138155
ValueRef weight = args[1];
139156
ValueRef bias = args[2];
140157
ValueRef stride = args[3];
141158
ValueRef padding = args[4];
142159
ValueRef dilation = args[5];
143-
ValueRef out = args[7];
160+
ValueRef out = args[9];
144161

145-
add_conv1d_dw_node(graph, in, weight, bias, stride, padding, dilation, out);
162+
float output_min = std::numeric_limits<float>::lowest();
163+
float output_max = std::numeric_limits<float>::max();
164+
if (is_valid(args[7])) {
165+
output_min = graph.extract_scalar<float>(args[7]);
166+
}
167+
if (is_valid(args[8])) {
168+
output_max = graph.extract_scalar<float>(args[8]);
169+
}
170+
171+
add_conv1d_dw_node(
172+
graph,
173+
in,
174+
weight,
175+
bias,
176+
stride,
177+
padding,
178+
dilation,
179+
out,
180+
output_min,
181+
output_max);
146182
}
147183

148184
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1818

19+
#include <limits>
20+
1921
namespace vkcompute {
2022

2123
// Minimum number of thread groups to target for good GPU occupancy.
@@ -117,11 +119,15 @@ void resize_conv1d_pw_node(
117119

118120
struct Conv1dPWIntParams final {
119121
int32_t weight_B;
122+
float output_min;
123+
float output_max;
120124
};
121125

122126
struct Conv1dPWBiasParams final {
123127
float alpha;
124128
float beta;
129+
float output_min;
130+
float output_max;
125131
};
126132

127133
vkapi::ShaderInfo pick_conv1d_pw_shader(
@@ -181,7 +187,9 @@ void add_conv1d_pw_node(
181187
const ValueRef in,
182188
const ValueRef weight_data,
183189
const ValueRef bias,
184-
const ValueRef out) {
190+
const ValueRef out,
191+
const float output_min = std::numeric_limits<float>::lowest(),
192+
const float output_max = std::numeric_limits<float>::max()) {
185193
VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim);
186194
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim);
187195

@@ -199,20 +207,21 @@ void add_conv1d_pw_node(
199207
ValueRef C_out_ref = graph.add_scalar(C_out);
200208
ValueRef has_bias_ref = graph.add_scalar(has_bias);
201209

202-
Conv1dPWIntParams int_params{1};
203-
Conv1dPWBiasParams bias_params{1.0f, 1.0f};
210+
Conv1dPWIntParams int_params{1, output_min, output_max};
211+
Conv1dPWBiasParams bias_params{1.0f, 1.0f, output_min, output_max};
204212

205213
std::vector<ValueRef> read_inputs = {in, packed_weight};
206214
if (has_bias) {
207215
read_inputs.push_back(packed_bias);
208216
}
209217

210-
std::vector<PushConstantDataInfo> push_constants = {
211-
PushConstantDataInfo(&int_params, sizeof(Conv1dPWIntParams)),
212-
};
218+
std::vector<PushConstantDataInfo> push_constants;
213219
if (has_bias) {
214220
push_constants.push_back(
215221
PushConstantDataInfo(&bias_params, sizeof(Conv1dPWBiasParams)));
222+
} else {
223+
push_constants.push_back(
224+
PushConstantDataInfo(&int_params, sizeof(Conv1dPWIntParams)));
216225
}
217226

218227
vkapi::ParamsBindList shader_params = {
@@ -240,20 +249,31 @@ void add_conv1d_pw_node(
240249
resize_conv1d_pw_node));
241250
}
242251

252+
// Args: in, weight, bias, stride, padding, dilation, groups,
253+
// output_min, output_max, out
254+
// output_min and output_max may be kDummyValueRef (no clamp).
243255
void conv1d_pw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
244-
// args: in, weight, bias, stride, padding, dilation, groups, out
245256
ValueRef in = args[0];
246257
ValueRef weight = args[1];
247258
ValueRef bias = args[2];
248-
ValueRef out = args[7];
259+
ValueRef out = args[9];
249260

250261
const std::vector<int64_t> weight_sizes = graph.sizes_of(weight);
251262
VK_CHECK_COND(
252263
weight_sizes.at(2) == 1, "conv1d_pw only supports kernel_size=1");
253264
VK_CHECK_COND(
254265
graph.get_int(args[6]) == 1, "conv1d_pw only supports groups=1");
255266

256-
add_conv1d_pw_node(graph, in, weight, bias, out);
267+
float output_min = std::numeric_limits<float>::lowest();
268+
float output_max = std::numeric_limits<float>::max();
269+
if (is_valid(args[7])) {
270+
output_min = graph.extract_scalar<float>(args[7]);
271+
}
272+
if (is_valid(args[8])) {
273+
output_max = graph.extract_scalar<float>(args[8]);
274+
}
275+
276+
add_conv1d_pw_node(graph, in, weight, bias, out, output_min, output_max);
257277
}
258278

259279
REGISTER_OPERATORS {

0 commit comments

Comments
 (0)