Commit 6989cf8
[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for
grouped conv (#6327)
## Motivation
The ML heuristic in dispatcher does not support grouped-conv operator
yet. In this PR, the support for fwd, bdw-data, and bwd-weight
grouped-conv kernels have been added. A tile_engine utility has also
been added to compile and run any selected kernel configuration through
dispatcher infrastructure.
## Technical Details
1. Tile engine utility is added to benchmark each shape with all the
possible kernel+tile_size combinations here -
[https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url)
2. New LGBM regressor models for grouped conv are added to models
directory. We have 3 separate models for fwd, bwd-data, and bwd-weights
[https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url)
3. Implemented lazy GPU initialization (dispatcher/python)
- **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
access faults
- **Solution**: Mirror FMHA pattern - defer GPU initialization until
first run()
- **Changes**:
- setup_multiple_grouped_conv_dispatchers() returns List[Path], not
loaded libs
- GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
- Added _ensure_initialized() method for lazy GPU loading
- GPU context created only on first run() call
- **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed few miscellaneous issues such as:
- Fixed BF16->FP16 naming bug in the dispatcher wrapper
- Added new tile sizes, and comp_v5 pipeline to the arch spec to expand
the kernel selection
- Added automatic padding support for unsupported shapes in dispatcher
runner
- Created a single source of truth between tile_engine and dispatcher
about the architecture and tile_size details
- Build a validation scripts to compare oracle_best vs ml_heuristic
comparison
## Test Plan
1. Validated fwd, bwd-data, and bwd-weight kernels with both known and
unseen data sets with up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.
## Test Result
Results on Unseen shapes validated on gfx950
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
- Mean Efficiency: **93.05%**
- Median Efficiency: **96.8%**
- P10 Efficiency: **79.9%**
#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
- Mean Efficiency: **93.8%**
- Median Efficiency: **96.5%**
- P10 Efficiency: **82.9%**
#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
- Mean Efficiency: **96.1%**
- Median Efficiency: **99.2%**
- P10 Efficiency: **89.4%**
## Submission Checklist
- [ x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.1 parent b05040b commit 6989cf8
65 files changed
Lines changed: 13204 additions & 387 deletions
File tree
- dispatcher
- bindings/ctypes
- codegen
- examples/grouped_conv/python
- heuristics
- models
- grouped_conv_bwd_data_bf16_gfx950
- grouped_conv_bwd_weight_bf16_gfx950
- grouped_conv_forward_2d3d_suffix_bf16_gfx950
- grouped_conv_forward_bf16_gfx950
- tests
- validation
- grouped_conv
- include/ck_tile/dispatcher
- python
- tile_engine/ops/grouped_conv
- problems
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
129 | 129 | | |
130 | 130 | | |
131 | 131 | | |
132 | | - | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
133 | 148 | | |
134 | 149 | | |
135 | 150 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
81 | 81 | | |
82 | 82 | | |
83 | 83 | | |
84 | | - | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
85 | 87 | | |
86 | 88 | | |
87 | 89 | | |
| |||
256 | 258 | | |
257 | 259 | | |
258 | 260 | | |
259 | | - | |
260 | | - | |
| 261 | + | |
| 262 | + | |
261 | 263 | | |
262 | 264 | | |
263 | 265 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | | - | |
2 | 1 | | |
3 | 2 | | |
4 | 3 | | |
5 | 4 | | |
6 | 5 | | |
7 | 6 | | |
8 | | - | |
| 7 | + | |
9 | 8 | | |
10 | 9 | | |
11 | 10 | | |
| |||
50 | 49 | | |
51 | 50 | | |
52 | 51 | | |
53 | | - | |
| 52 | + | |
54 | 53 | | |
55 | 54 | | |
56 | 55 | | |
| |||
226 | 225 | | |
227 | 226 | | |
228 | 227 | | |
| 228 | + | |
| 229 | + | |
229 | 230 | | |
230 | 231 | | |
231 | 232 | | |
232 | 233 | | |
233 | 234 | | |
234 | 235 | | |
235 | 236 | | |
| 237 | + | |
| 238 | + | |
236 | 239 | | |
237 | 240 | | |
238 | 241 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
230 | 230 | | |
231 | 231 | | |
232 | 232 | | |
233 | | - | |
| 233 | + | |
234 | 234 | | |
235 | 235 | | |
236 | 236 | | |
| |||
288 | 288 | | |
289 | 289 | | |
290 | 290 | | |
291 | | - | |
292 | | - | |
| 291 | + | |
| 292 | + | |
293 | 293 | | |
294 | 294 | | |
295 | 295 | | |
296 | | - | |
| 296 | + | |
297 | 297 | | |
298 | 298 | | |
299 | 299 | | |
| |||
0 commit comments