
Commit 6989cf8

yraparti authored and assistant-librarian[bot] committed
[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation

The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations:
   https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight:
   https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch after this description.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context is created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and the dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic results

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets of up to 300 problems.
2. Ensured that test cases are added in both the dispatcher and tile_engine to validate the heuristic.

## Test Results

Results on unseen shapes, validated on gfx950.

#### Forward Pass Model

- **Training Data**: 48,845 measurements across 1,372 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model

- **Training Data**: 18,773 measurements across 891 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model

- **Training Data**: 34,900 measurements across 1,508 unique problem shapes
- **Validation Set**: 300 unseen problems from the model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
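The lazy-initialization change in item 3 above is the most structural one. Below is a minimal sketch of the pattern, assuming a simplified layout; only setup_multiple_grouped_conv_dispatchers(), GpuGroupedConvRunner.__init__(), _ensure_initialized(), and run() are named in this PR, and the bodies here are illustrative, not the real implementation.

```python
# Hedged sketch of the lazy GPU initialization described above; the real
# code lives in dispatcher/python and its exact fields/signatures may differ.
import ctypes
from pathlib import Path
from typing import List, Optional


def setup_multiple_grouped_conv_dispatchers(lib_dirs: List[Path]) -> List[Path]:
    """Collect dispatcher shared-library paths WITHOUT loading them.

    Returning plain paths keeps ProcessPoolExecutor fork() workers free of
    any GPU context, avoiding the memory access faults noted above.
    """
    return [p for d in lib_dirs for p in sorted(d.glob("*.so"))]


class GpuGroupedConvRunner:
    def __init__(self, lib_path: Path):
        # No ctypes.CDLL here: __init__ stays GPU-free so the runner can be
        # constructed (and forked) before any GPU work happens.
        self._lib_path = lib_path
        self._lib: Optional[ctypes.CDLL] = None

    def _ensure_initialized(self) -> None:
        # Lazy load: the GPU context is created only on the first run() call.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args) -> float:
        self._ensure_initialized()
        fn = self._lib.grouped_conv_run  # hypothetical C entry point
        fn.restype = ctypes.c_float
        return float(fn(*args))
```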
1 parent b05040b commit 6989cf8

65 files changed

Lines changed: 13204 additions & 387 deletions


dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp

Lines changed: 16 additions & 1 deletion
@@ -129,7 +129,22 @@ float conv_bwdw_run(const void* input_ptr,
         return -1.0f;
     if(!input_ptr || !grad_output_ptr || !grad_weight_ptr)
         return -1.0f; // Null data pointer would cause kernel crash
-    return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
+
+    try
+    {
+        return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
+    }
+    catch(const std::exception&)
+    {
+        // Kernel rejected args (e.g. unsupported tile/channel combo)
+        // -3.0f matches conv_ctypes_lib.cpp:316 convention
+        // -2.0f is reserved for "no kernel / not compiled for this direction"
+        return -3.0f;
+    }
+    catch(...)
+    {
+        return -3.0f;
+    }
 #else
     return -1.0f;
 #endif
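The wrapper's negative return values form a small error protocol: -1.0f for invalid or null arguments, -2.0f reserved for "no kernel / not compiled for this direction", and -3.0f for arguments the kernel rejected. A hedged sketch of how a ctypes caller might surface these codes; conv_bwdw_run is the real entry point from the diff above, while the loading and argument handling are illustrative:

```python
# Illustrative caller-side handling of the return-code convention above.
# Only conv_bwdw_run and the -1/-2/-3 codes come from the diff; the rest
# (argtypes being configured elsewhere, timing semantics) is assumed.
import ctypes

_ERRORS = {
    -1.0: "invalid or null arguments",
    -2.0: "no kernel / not compiled for this direction",
    -3.0: "kernel rejected args (e.g. unsupported tile/channel combo)",
}


def run_bwd_weight(lib: ctypes.CDLL, *args) -> float:
    """Run the bwd-weight kernel; return the (assumed) timing on success."""
    lib.conv_bwdw_run.restype = ctypes.c_float
    result = float(lib.conv_bwdw_run(*args))
    if result < 0.0:
        raise RuntimeError(_ERRORS.get(result, f"unknown error code {result}"))
    return result
```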

dispatcher/codegen/arch_specs.json

Lines changed: 5 additions & 3 deletions
@@ -81,7 +81,9 @@
     "warp_configs": [
       [1, 4, 1],
       [2, 2, 1],
-      [4, 1, 1]
+      [4, 1, 1],
+      [8, 2, 1],
+      [4, 4, 1]
     ],
     "warp_tile_combos": {
       "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
@@ -256,8 +258,8 @@
       "int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
     },
     "gfx950": {
-      "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
-      "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
+      "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
+      "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16], [32, 32, 32], [16, 16, 64]],
       "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
       "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
     }
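One way to read the new gfx950 entries: if a block tile is the elementwise product of a warp config [m_warps, n_warps, k_warps] and a warp tile [m, n, k] — the usual CK tile composition, though this diff does not spell it out — the additions enlarge the candidate space roughly as follows:

```python
# Hedged illustration of how the new gfx950 warp configs and fp16/bf16 warp
# tiles above expand the block-tile candidate space. The elementwise-product
# composition is an assumption, not something stated in the diff.
from itertools import product

warp_configs = [[1, 4, 1], [2, 2, 1], [4, 1, 1], [8, 2, 1], [4, 4, 1]]  # incl. new [8,2,1], [4,4,1]
new_warp_tiles = [[32, 32, 32], [16, 16, 64]]  # new fp16/bf16 entries

for (wm, wn, wk), (tm, tn, tk) in product(warp_configs, new_warp_tiles):
    print(f"candidate block tile: {wm * tm} x {wn * tn} x {wk * tk}")
```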

dispatcher/codegen/arch_specs_generated.py

Lines changed: 6 additions & 3 deletions
@@ -1,11 +1,10 @@
-# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 # SPDX-License-Identifier: MIT
 
 """
 AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
 
 Generated from: arch_specs.json
-Generated at: 2026-01-05T19:34:01.224422
+Generated at: 2026-04-10T20:07:11.665064
 
 To update this file:
     1. Edit arch_specs.json
@@ -50,7 +49,7 @@
     "gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
     "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
     "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
-    "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
+    "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1], [8, 2, 1], [4, 4, 1]],
     "gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
     "gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
     "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
@@ -226,13 +225,17 @@
         [32, 32, 16],
         [16, 16, 32],
         [64, 4, 16],
+        [32, 32, 32],
+        [16, 16, 64],
     ],
     "bf16_bf16_fp32": [
         [32, 32, 8],
         [16, 16, 16],
         [32, 32, 16],
         [16, 16, 32],
         [64, 4, 16],
+        [32, 32, 32],
+        [16, 16, 64],
     ],
     "fp8_fp8_fp32": [
         [32, 32, 16],

dispatcher/codegen/generate_arch_specs.py

Lines changed: 4 additions & 4 deletions
@@ -230,7 +230,7 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
 
     for arch, data in archs.items():
         enum_name = arch.upper().replace("GFX", "GFX_")
-        arch_enums.append(f"    {enum_name},  // {data['description']}")
+        arch_enums.append(f"    {enum_name},")
         arch_to_string_cases.append(
             f'    case GpuArch::{enum_name}: return "{arch}";'
         )
@@ -288,12 +288,12 @@ def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
             f"    if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};"
         )
 
-    content = f"""// SPDX-License-Identifier: MIT
-// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+    content = f"""// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
 
 /**
  * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
- *
+ *
 * Generated from: arch_specs.json
 * Generated at: {timestamp}
 *
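For context, here is a self-contained sketch of the enum-generation step touched above; the JSON layout is invented for illustration, while the enum_name derivation, the dropped description comments, and the Copyright-then-SPDX header order match the diff:

```python
# Minimal sketch of the codegen step changed above: arch entries from
# arch_specs.json become a C++ GpuArch enum plus a to_string() switch.
import json

specs_json = '{"archs": {"gfx942": {}, "gfx950": {}}}'  # assumed layout
archs = json.loads(specs_json)["archs"]

arch_enums, arch_to_string_cases = [], []
for arch in archs:
    enum_name = arch.upper().replace("GFX", "GFX_")  # gfx950 -> GFX_950
    arch_enums.append(f"    {enum_name},")
    arch_to_string_cases.append(f'        case GpuArch::{enum_name}: return "{arch}";')

header = (
    "// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n"
    "// SPDX-License-Identifier: MIT\n\n"
    "enum class GpuArch {\n" + "\n".join(arch_enums) + "\n};\n\n"
    "inline const char* to_string(GpuArch arch) {\n    switch (arch) {\n"
    + "\n".join(arch_to_string_cases)
    + '\n        default: return "unknown";\n    }\n}\n'
)
print(header)
```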
