Skip to content

Commit 0934d48

Browse files
author
ssjia
committed
[ET-VK][conv1d] Implement height-packed depthwise conv1d operator
Implement a depthwise conv1d operator using height-packed layout where channels are the packed dimension (WHCN dim 1). Depthwise conv applies a separate filter to each channel independently (groups=C), so 4 channels can be processed in parallel using element-wise vec4 FMA over kernel positions.

Thread mapping: X=C/4, Y=L_out, Z=N. Each thread computes one output texel (4 channels at one spatial position). The inner loop iterates over kernel positions K with bounds-checked input access for padding. Weight [C,1,K] is prepacked as channels-packed so each vec4 load gives 4 channels' weights at one kernel position.

Supports both buffer and texture3d storage, fp32/fp16, optional bias, and arbitrary stride/padding/dilation. Registered as et_vk.conv1d_dw.default (standalone custom op).

Performance on Adreno 750 (S24):
- [1,128,4096] K=31 buffer f16: 231 GFLOP/s
- [1,128,4096] K=31 buffer f32: 155 GFLOP/s
- [1,512,2048] K=5 buffer f32: 66 GFLOP/s

Differential Revision: [D97344091](https://our.internmc.facebook.com/intern/diff/D97344091/)

[ghstack-poisoned]
1 parent b8ba505 commit 0934d48

6 files changed

Lines changed: 560 additions & 0 deletions

File tree

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
#define T ${texel_load_component_type(DTYPE, STORAGE)}
14+
15+
$if STORAGE == "buffer":
16+
#define BUFFER
17+
$if HAS_BIAS:
18+
#define HAS_BIAS
19+
20+
${define_required_extensions(STORAGE, DTYPE)}
21+
22+
layout(std430) buffer;
23+
24+
#include "common.glslh"
25+
26+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
27+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
28+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE, is_scalar_array=False)}
29+
$if HAS_BIAS:
30+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE, is_scalar_array=False)}
31+
32+
// in_sizes: {L_in, C, N, 1} in WHCN order
33+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
34+
// out_sizes: {L_out, C, N, 1} in WHCN order
35+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
36+
37+
// Convolution hyper-parameters supplied as push constants. The host side of
// this commit fills the block from a struct that declares the same four
// 32-bit integer fields in the same order.
layout(push_constant) uniform restrict Block {
38+
// Number of kernel positions visited by the loop in main().
int kernel_size;
39+
// Multiplied by the output position to locate the start of the input window.
int stride;
40+
// Subtracted from the window start; input positions outside [0, L_in) are
// skipped in main().
int padding;
41+
// Spacing, in input positions, between consecutive kernel positions.
int dilation;
42+
};
43+
44+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
45+
46+
// Thread mapping: X = C/4, Y = L_out, Z = N
47+
// Each thread computes 4 output channels at one spatial position.
48+
// Depthwise: each channel has its own filter, so 4 channels can be computed
49+
// independently with element-wise vec4 FMA.
50+
51+
void main() {
  // Invocation mapping: x selects a group of 4 channels, y an output position,
  // z a batch entry.
  const int c4 = int(gl_GlobalInvocationID.x);
  const int l_out = int(gl_GlobalInvocationID.y);
  const int n = int(gl_GlobalInvocationID.z);

  const int L_in = in_sizes.x;
  const int C = in_sizes.y;
  const int C4 = div_up_4(C);
  const int L_out = out_sizes.x;
  const int N = out_sizes.z;

  // Vulkan dispatches whole workgroups, so surplus invocations may exist on
  // any axis whose extent is not a multiple of the local size. The batch axis
  // has to be guarded as well, otherwise such invocations index past the end
  // of t_in / t_out.
  if (c4 >= C4 || l_out >= L_out || n >= N) {
    return;
  }

  VEC4_T sum = VEC4_T(0);

  // Input position of kernel tap 0; negative while inside the left padding.
  const int l_in_start = l_out * stride - padding;

  for (int k = 0; k < kernel_size; k++) {
    const int l_in = l_in_start + k * dilation;
    // Taps that land in the padding region contribute zero and are skipped.
    if (l_in >= 0 && l_in < L_in) {
#ifdef BUFFER
      const VEC4_T in_val = t_in[(n * L_in + l_in) * C4 + c4];
      const VEC4_T w_val = t_weight[k * C4 + c4];
#else
      const VEC4_T in_val = texelFetch(t_in, ivec3(l_in, c4, n), 0);
      const VEC4_T w_val = texelFetch(t_weight, ivec3(k, 0, c4), 0);
#endif
      // Element-wise: each of the 4 lanes is a separate channel with its own
      // filter weight.
      sum = fma(w_val, in_val, sum);
    }
  }

#ifdef HAS_BIAS
  // One bias texel holds the bias of the same 4 channels as this invocation.
#ifdef BUFFER
  sum += t_bias[c4];
#else
  sum += texelFetch(t_bias, ivec3(c4, 0, 0), 0);
#endif
#endif

#ifdef BUFFER
  t_out[(n * L_out + l_out) * C4 + c4] = sum;
#else
  imageStore(t_out, ivec3(l_out, c4, n), sum);
#endif
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Variant configuration for the conv1d_dw shader template.
conv1d_dw:
8+
# Template parameters with the value each one takes unless a variant below
# overrides it.
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
HAS_BIAS: false
12+
# Parameters swept over all listed values (presumably every STORAGE x DTYPE
# combination is generated for each shader variant; confirm against the
# shader codegen).
generate_variant_forall:
13+
STORAGE:
14+
- VALUE: texture3d
15+
- VALUE: buffer
16+
DTYPE:
17+
- VALUE: float
18+
- VALUE: half
19+
# Base shader names. The host code selects "conv1d_dw_bias" when a bias tensor
# is provided and "conv1d_dw" otherwise, then appends storage and dtype
# suffixes.
shader_variants:
20+
- NAME: conv1d_dw
21+
- NAME: conv1d_dw_bias
22+
HAS_BIAS: true
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
13+
14+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16+
17+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
18+
19+
namespace vkcompute {
20+
21+
/*
 * Resize callback for the depthwise conv1d node: recomputes the output length
 * from the current input length and the convolution parameters, then resizes
 * the output to {in_sizes[0], in_sizes[1], L_out}.
 *
 * extra_args layout: {weight tensor ref, stride list, padding list,
 * dilation list}. Only the first element of each list is read.
 */
void resize_conv1d_dw_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef out_tensor = args.at(0).refs.at(0);
  const ValueRef in_tensor = args.at(1).refs.at(0);

  TensorRefPtr weights = graph->get_tref(extra_args.at(0));

  // Reads element 0 of an int-list value.
  const auto first_of = [graph](const ValueRef list_ref) -> int64_t {
    return graph->get_int_list(list_ref)->at(0);
  };
  const int64_t conv_stride = first_of(extra_args.at(1));
  const int64_t conv_padding = first_of(extra_args.at(2));
  const int64_t conv_dilation = first_of(extra_args.at(3));

  const std::vector<int64_t> input_sizes = graph->sizes_of(in_tensor);
  // The kernel extent is taken from dim 2 of the weight.
  const int64_t taps = weights->sizes.at(2);
  const int64_t in_len = input_sizes.at(2);

  const int64_t out_len = calc_out_size(
      in_len, taps, conv_stride, conv_padding, conv_dilation, false);

  const std::vector<int64_t> resized = {
      input_sizes.at(0), input_sizes.at(1), out_len};
  graph->virtual_resize(out_tensor, resized);
}
43+
44+
/*
 * Host-side mirror of the push-constant block read by the conv1d_dw shader.
 * The shader declares the same four int fields in the same order, and the
 * struct is handed to the dispatch as raw bytes, so the field order and size
 * must not change independently of the shader.
 */
struct Conv1dDWParams final {
  // Number of kernel positions (dim 2 of the weight).
  int32_t kernel_size;
  int32_t stride;
  int32_t padding;
  int32_t dilation;
};

// The struct is copied verbatim into the push-constant range; catch accidental
// padding or added fields at compile time.
static_assert(
    sizeof(Conv1dDWParams) == 4 * sizeof(int32_t),
    "Conv1dDWParams must stay tightly packed to match the shader's "
    "push-constant block");
50+
51+
/*
 * Global work group size picker handed to the dispatch node for conv1d_dw.
 * Returns {div_up_4(C), L_out, N}, read from the output tensor's sizes: C is
 * the second-to-last dim, L_out the last dim, and N the third-to-last dim
 * (1 when the output has fewer than 3 dims).
 */
utils::uvec3 pick_conv1d_dw_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  // Unused here; kept so the signature matches what the dispatch node is given.
  (void)shader;
  (void)resize_args;

  const ValueRef out_tensor = args.at(0).refs.at(0);

  const uint32_t channels = graph->size_at<uint32_t>(-2, out_tensor);
  const uint32_t out_len = graph->size_at<uint32_t>(-1, out_tensor);

  uint32_t batch = 1;
  if (graph->dim_of(out_tensor) >= 3) {
    batch = graph->size_at<uint32_t>(-3, out_tensor);
  }

  // One invocation covers 4 channels, hence the rounded-up division.
  const uint32_t channel_texels = utils::div_up_4(channels);
  return {channel_texels, out_len, batch};
}
68+
69+
/*
 * Adds a depthwise conv1d dispatch to the graph.
 *
 * in / out      height-packed tensors (checked below).
 * weight_data   constant weight of shape [C, 1, K]; prepacked here.
 * bias          optional; when not none it is prepacked and the "_bias"
 *               shader variant is selected.
 * stride_ref, padding_ref, dilation_ref
 *               int lists; only element 0 of each is used.
 *
 * Throws (via VK_CHECK_COND) when the layouts, the weight shape, or the
 * convolution parameters do not match what the shader supports.
 */
void add_conv1d_dw_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight_data,
    const ValueRef bias,
    const ValueRef stride_ref,
    const ValueRef padding_ref,
    const ValueRef dilation_ref,
    const ValueRef out) {
  VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kHeightDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kHeightDim);

  // Copy the sizes out so no tensor-ref handle is held while values are added
  // to the graph below.
  const std::vector<int64_t> weight_sizes = graph.get_tref(weight_data)->sizes;

  // The shader reads one filter per input channel and ignores weight dim 1,
  // so anything other than [C, 1, K] would silently compute the wrong result
  // or index outside the packed weight.
  VK_CHECK_COND(weight_sizes.size() == 3);
  VK_CHECK_COND(weight_sizes.at(1) == 1);
  VK_CHECK_COND(weight_sizes.at(0) == graph.size_at<int64_t>(-2, in));

  const int64_t stride_val = graph.get_int_list(stride_ref)->at(0);
  const int64_t padding_val = graph.get_int_list(padding_ref)->at(0);
  const int64_t dilation_val = graph.get_int_list(dilation_ref)->at(0);

  // Reject parameters for which the output-size computation and the shader's
  // index arithmetic are not meaningful.
  VK_CHECK_COND(stride_val > 0);
  VK_CHECK_COND(dilation_val > 0);
  VK_CHECK_COND(padding_val >= 0);

  const utils::StorageType storage_type = graph.storage_type_of(out);

  // Weight [C, 1, K] prepacked as channels-packed so each vec4 load gives
  // 4 channels at one kernel position.
  ValueRef packed_weight = prepack_standard(
      graph, weight_data, storage_type, utils::kChannelsPacked);

  const bool has_bias = graph.val_is_not_none(bias);
  ValueRef packed_bias = kDummyValueRef;
  if (has_bias) {
    packed_bias =
        prepack_standard(graph, bias, storage_type, utils::kWidthPacked);
  }

  // Field order must match the shader's push-constant block.
  Conv1dDWParams params{
      utils::safe_downcast<int32_t>(weight_sizes.at(2)),
      utils::safe_downcast<int32_t>(stride_val),
      utils::safe_downcast<int32_t>(padding_val),
      utils::safe_downcast<int32_t>(dilation_val),
  };

  std::string kernel_name = has_bias ? "conv1d_dw_bias" : "conv1d_dw";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, storage_type);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // The bias is bound only for the "_bias" variant, which is the only one
  // that declares t_bias.
  std::vector<ValueRef> read_inputs = {in, packed_weight};
  if (has_bias) {
    read_inputs.push_back(packed_bias);
  }

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_conv1d_dw_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {read_inputs, vkapi::kRead}},
      // Shader params buffers
      {graph.sizes_ubo(in), graph.sizes_ubo(out)},
      // Push Constants
      {PushConstantDataInfo(&params, sizeof(Conv1dDWParams))},
      // Specialization Constants
      {},
      // Resize Args
      {weight_data, stride_ref, padding_ref, dilation_ref},
      // Resizing Logic
      resize_conv1d_dw_node));
}
134+
135+
/*
 * Operator entry point for et_vk.conv1d_dw.default.
 *
 * args: in, weight, bias, stride, padding, dilation, groups, out
 *
 * Element access is bounds-checked, so a call with fewer than 8 arguments
 * throws std::out_of_range instead of reading past the end of the vector.
 */
void conv1d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args.at(0);
  const ValueRef weight = args.at(1);
  const ValueRef bias = args.at(2);
  const ValueRef stride = args.at(3);
  const ValueRef padding = args.at(4);
  const ValueRef dilation = args.at(5);
  // args.at(6) is `groups`; it is not read here.
  // NOTE(review): the depthwise kernel presumably assumes groups == C; confirm
  // that callers guarantee this.
  const ValueRef out = args.at(7);

  add_conv1d_dw_node(graph, in, weight, bias, stride, padding, dilation, out);
}
147+
148+
// Registers conv1d_dw under the operator name et_vk.conv1d_dw.default.
REGISTER_OPERATORS {
149+
VK_REGISTER_OP(et_vk.conv1d_dw.default, conv1d_dw);
150+
}
151+
152+
} // namespace vkcompute
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
12+
13+
namespace vkcompute {
14+
15+
/*
 * Test-only operator: looks up whatever is registered under
 * "et_vk.conv1d_dw.default" and forwards the arguments to it unchanged.
 *
 * args: in, weight, bias, stride, padding, dilation, groups, out
 */
void test_conv1d_dw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const auto& conv1d_dw_impl = VK_GET_OP_FN("et_vk.conv1d_dw.default");
  conv1d_dw_impl(graph, args);
}
19+
20+
// Registers the forwarding wrapper under test_etvk.test_conv1d_dw.default.
REGISTER_OPERATORS {
21+
VK_REGISTER_OP(test_etvk.test_conv1d_dw.default, test_conv1d_dw);
22+
}
23+
24+
} // namespace vkcompute

backends/vulkan/test/custom_ops/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,4 @@ def define_common_targets(is_fbcode = False):
103103
define_custom_op_test_binary("test_conv2d_pw")
104104
define_custom_op_test_binary("test_conv2d_dw")
105105
define_custom_op_test_binary("test_conv1d_pw")
106+
define_custom_op_test_binary("test_conv1d_dw")

0 commit comments

Comments
 (0)