Skip to content

Commit 10f6c36

Browse files
ssjia authored and chizkiyahu committed
[ET-VK][qconv] Dynamically select between im2col path and general path
Pull Request resolved: pytorch#17387

This adds a dispatch layer to `q8ta_conv2d` that dynamically selects between the im2col-based and general convolution implementations at graph build time. The existing `q8ta_conv2d` function is renamed to `q8ta_conv2d_general`, and a new `q8ta_conv2d` dispatcher is introduced that chooses the im2col path when the convolution is non-grouped, has input channels divisible by 4, and kernel size ≤ 3x3. All other cases fall through to the general path. A separate `q8ta_conv2d_general` operator is also registered so tests can directly invoke the general path for comparison.

The test suite is updated to exercise both the general and im2col implementations explicitly, and the default impl_selector is changed from "general" to empty (which triggers the new dispatcher). FP buffer storage types are removed from the test matrix since they are not needed.

ghstack-source-id: 340983078
@exported-using-ghexport

Differential Revision: [D93000162](https://our.internmc.facebook.com/intern/diff/D93000162/)
1 parent 8f6fe44 commit 10f6c36

6 files changed

Lines changed: 52 additions & 23 deletions

File tree

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ void add_q8ta_conv2d_node(
323323
// High level operator impl
324324
//
325325

326-
void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
326+
void q8ta_conv2d_general(
327+
ComputeGraph& graph,
328+
const std::vector<ValueRef>& args) {
327329
int32_t idx = 0;
328330
const ValueRef packed_int8_input = args.at(idx++);
329331
const ValueRef input_scale = args.at(idx++);
@@ -398,8 +400,30 @@ void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
398400
packed_int8_output);
399401
}
400402

403+
// Dispatcher for the q8ta_conv2d operator: inspects the convolution
// parameters and forwards `args` unchanged to either q8ta_conv2d_im2col or
// q8ta_conv2d_general.
//
// The im2col path is taken only when all of the following hold:
//   - groups == 1 (non-grouped convolution)
//   - IC, read as the size of the input's third-from-last dim, is a
//     multiple of 4
//   - both kernel dimensions are <= 3
// Every other configuration goes to q8ta_conv2d_general.
//
// NOTE(review): the positions 0, 9 and 13 used below are hard-coded; they
// presumably mirror the argument order unpacked by q8ta_conv2d_general --
// confirm whenever the operator's argument list changes.
void q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
404+
// Index into args to extract values needed for dispatch decision
405+
const ValueRef packed_int8_input = args.at(0);
406+
const ValueRef kernel_size = args.at(9);
407+
const ValueRef groups = args.at(13);
408+
409+
const int32_t groups_val = graph.get_int(groups);
410+
// -3 selects the third-from-last dim of the input, used here as the input
// channel count.
const int64_t IC = graph.size_at<int64_t>(-3, packed_int8_input);
411+
412+
// kernel_size is read as an int list: element 0 is used as the kernel
// height and element 1 as the kernel width.
const int64_t K_h = graph.get_int_list(kernel_size)->at(0);
413+
const int64_t K_w = graph.get_int_list(kernel_size)->at(1);
414+
415+
// Use im2col path when: non-grouped, input channels multiple of 4, small
416+
// kernel
417+
if (groups_val == 1 && IC % 4 == 0 && K_h <= 3 && K_w <= 3) {
418+
q8ta_conv2d_im2col(graph, args);
419+
} else {
420+
q8ta_conv2d_general(graph, args);
421+
}
422+
}
423+
401424
// Operator registrations for this file. The q8ta_conv2d dispatcher is
// registered as etvk.q8ta_conv2d.default; q8ta_conv2d_general is additionally
// registered under its own name so it can be invoked directly, bypassing the
// dispatcher.
REGISTER_OPERATORS {
402425
VK_REGISTER_OP(etvk.q8ta_conv2d.default, q8ta_conv2d);
426+
VK_REGISTER_OP(etvk.q8ta_conv2d_general.default, q8ta_conv2d_general);
403427
}
404428

405429
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,6 @@ void add_q8ta_conv2d_pw_node(
113113
const ValueRef packed_bias,
114114
const ValueRef packed_int8_output);
115115

116+
// Entry point of the im2col-based q8ta conv2d implementation. Declared in
// this header so that the q8ta_conv2d dispatcher can call it; it takes the
// same (graph, args) pair as the other q8ta_conv2d entry points.
// NOTE(review): the definition is not part of this header -- presumably it
// lives in the im2col implementation file; confirm.
void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector<ValueRef>& args);
117+
116118
} // namespace vkcompute

backends/vulkan/test/custom_ops/impl/TestQ8taConv2d.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ void test_q8ta_conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
157157
} else if (impl_selector == "im2col") {
158158
// Use the im2col-based conv2d operator
159159
VK_GET_OP_FN("etvk.q8ta_conv2d_im2col.default")(graph, conv_args);
160+
} else if (impl_selector == "general") {
161+
// Use the general q8ta_conv2d operator (no im2col dispatch)
162+
VK_GET_OP_FN("etvk.q8ta_conv2d_general.default")(graph, conv_args);
160163
} else {
161164
// Use the default q8ta_conv2d operator, which dispatches between the im2col
// and general implementations
162165
VK_GET_OP_FN("etvk.q8ta_conv2d.default")(graph, conv_args);

backends/vulkan/test/custom_ops/test_q8ta_conv2d.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ static TestCase create_test_case_from_config(
2929
vkapi::ScalarType input_dtype,
3030
utils::StorageType fp_storage_type,
3131
utils::GPUMemoryLayout int8_memory_layout,
32-
const std::string& impl_selector = "general") {
32+
const std::string& impl_selector = "") {
3333
TestCase test_case;
3434

3535
// Calculate output dimensions
@@ -53,7 +53,6 @@ static TestCase create_test_case_from_config(
5353
std::to_string(config.input_size.w) + " " +
5454
"g=" + std::to_string(config.groups) + " " +
5555
"k=" + std::to_string(config.kernel.h) + " " +
56-
repr_str(fp_storage_type, fp_memory_layout) + "->" +
5756
repr_str(utils::kBuffer, int8_memory_layout);
5857
if (!impl_selector.empty()) {
5958
test_name += " [" + impl_selector + "]";
@@ -218,8 +217,7 @@ std::vector<TestCase> generate_quantized_conv2d_easy_cases() {
218217
};
219218
config.op_name = "conv2d_q8ta_q8csw_q8to";
220219

221-
std::vector<utils::StorageType> fp_storage_types = {
222-
utils::kTexture3D, utils::kBuffer};
220+
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
223221

224222
// Memory layouts for int8 tensors - test both optimized (4W4C) and general
225223
// paths
@@ -379,8 +377,7 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
379377
4}};
380378

381379
// Test with different storage types and memory layouts
382-
std::vector<utils::StorageType> fp_storage_types = {
383-
utils::kTexture3D, utils::kBuffer};
380+
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
384381

385382
// Memory layouts for int8 tensors - test both optimized (4W4C) and general
386383
// paths
@@ -401,29 +398,37 @@ static std::vector<TestCase> generate_quantized_conv2d_test_cases() {
401398
int8_memory_layouts) {
402399
config.test_case_name = make_test_case_name(
403400
config, is_performance, fp_storage_type, utils::kBuffer);
401+
404402
test_cases.push_back(create_test_case_from_config(
405-
config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
403+
config,
404+
vkapi::kFloat,
405+
fp_storage_type,
406+
int8_memory_layout,
407+
/*impl_selector=*/"general"));
406408

407-
// For 4W4C layout, also test the legacy implementation
408-
if (int8_memory_layout == utils::kPackedInt8_4W4C) {
409+
// Test im2col implementation for non-grouped convolutions with input
410+
// channels that are a multiple of 4 (NOTE: stride_w == 1 is mentioned as a
// requirement but is not checked by the condition below)
411+
if (config.groups == 1 && config.channels.in % 4 == 0) {
409412
test_cases.push_back(create_test_case_from_config(
410413
config,
411414
vkapi::kFloat,
412415
fp_storage_type,
413416
int8_memory_layout,
414-
/*impl_selector=*/"legacy_4w4c"));
417+
/*impl_selector=*/"im2col"));
415418
}
416419

417-
// Test im2col implementation for non-grouped convolutions with input
418-
// channels that are a multiple of 4 and stride_w == 1
419-
if (config.groups == 1 && config.channels.in % 4 == 0) {
420+
// For 4W4C layout, also test the legacy implementation
421+
if (int8_memory_layout == utils::kPackedInt8_4W4C) {
420422
test_cases.push_back(create_test_case_from_config(
421423
config,
422424
vkapi::kFloat,
423425
fp_storage_type,
424426
int8_memory_layout,
425-
/*impl_selector=*/"im2col"));
427+
/*impl_selector=*/"legacy_4w4c"));
426428
}
429+
430+
test_cases.push_back(create_test_case_from_config(
431+
config, vkapi::kFloat, fp_storage_type, int8_memory_layout));
427432
}
428433
}
429434
}

backends/vulkan/test/custom_ops/test_q8ta_conv2d_dw.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ TestCase create_test_case_from_config(
5454
std::to_string(config.input_size.w) + " " +
5555
"g=" + std::to_string(config.groups) + " " +
5656
"k=" + std::to_string(config.kernel.h) + " " +
57-
repr_str(fp_storage_type, fp_memory_layout) + "->" +
5857
repr_str(utils::kBuffer, int8_memory_layout);
5958
if (!impl_selector.empty()) {
6059
test_name += " [" + impl_selector + "]";
@@ -228,8 +227,7 @@ std::vector<TestCase> generate_quantized_conv2d_dw_easy_cases() {
228227
};
229228
config.op_name = "conv2d_q8ta_q8csw_q8to";
230229

231-
std::vector<utils::StorageType> fp_storage_types = {
232-
utils::kTexture3D, utils::kBuffer};
230+
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
233231

234232
// Memory layouts for int8 tensors - test both optimized (4W4C) and general
235233
// paths
@@ -351,8 +349,7 @@ std::vector<TestCase> generate_quantized_conv2d_dw_test_cases() {
351349
32}};
352350

353351
// Test with different storage types and data types
354-
std::vector<utils::StorageType> fp_storage_types = {
355-
utils::kTexture3D, utils::kBuffer};
352+
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
356353

357354
// Memory layouts for int8 tensors - test both optimized (4W4C) and general
358355
// paths

backends/vulkan/test/custom_ops/test_q8ta_conv2d_pw.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ static TestCase create_test_case_from_config(
5353
std::to_string(config.input_size.w) + " " +
5454
"g=" + std::to_string(config.groups) + " " +
5555
"k=" + std::to_string(config.kernel.h) + " " +
56-
repr_str(fp_storage_type, fp_memory_layout) + "->" +
5756
repr_str(utils::kBuffer, int8_memory_layout);
5857
if (!impl_selector.empty()) {
5958
test_name += " [" + impl_selector + "]";
@@ -286,8 +285,7 @@ static std::vector<TestCase> generate_quantized_conv2d_pw_test_cases() {
286285
};
287286

288287
// Test with different storage types and memory layouts
289-
std::vector<utils::StorageType> fp_storage_types = {
290-
utils::kTexture3D, utils::kBuffer};
288+
std::vector<utils::StorageType> fp_storage_types = {utils::kTexture3D};
291289

292290
// Memory layouts for int8 tensors - test both optimized (4W4C) and general
293291
// paths

0 commit comments

Comments
 (0)