
Commit 920acd2

[rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb)
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher

## Motivation

This PR adds CK Tile grouped convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now supported only GEMM operations. This PR lets framework integrators use the same declarative kernel workflow for convolutions as for GEMM: declare kernels, JIT-build a registry, select kernels from the registry at runtime, and dispatch to the GPU. Future PRs will add runtime kernel-selection heuristics for autotuning kernel parameters based on (problem, hardware arch).

## Technical Details

Grouped convolution support has been added to the CK Tile Dispatcher:

- `generated_conv_backend.hpp` enables `dispatcher.run(in, wei, out, problem)` for all 6 conv variants (fwd/bwd-data/bwd-weight x 2D/3D), runtime heuristic kernel selection, and `GroupedConvKernelKey` with full `ConvConfigBase` fields.
- The Python side adds parallel JIT via `registry.build(max_workers)` and heuristic `registry.select()`.
- Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, plus shared infrastructure improvements (BaseRegistry CRTP, structured exceptions).

As a sanity check, JIT compile time for a single kernel remains the same, and multiple kernels benefit from parallelism:

| Kernels | 1 worker | 8 workers |
|---------|----------|-----------|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |

## Test Plan

145 ephemeral unit tests have been added to cover basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) passes in both the C++ and Python examples.

## Test Result

All 30 examples pass.

Peak performance: 132 TFLOPS (batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
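The declare / build / select / dispatch workflow described above can be sketched with a toy registry and a stand-in selection heuristic. Everything here — `ConvKernelKey`, `select_kernel`, the padded-waste metric — is illustrative only, not the dispatcher's actual API.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ConvKernelKey:
    """Illustrative stand-in for GroupedConvKernelKey (real fields differ)."""
    direction: str     # "fwd", "bwd_data", or "bwd_weight"
    ndim_spatial: int  # 2 or 3
    tile_m: int
    tile_n: int

def select_kernel(registry, direction, ndim_spatial, m, n):
    """Pick the matching-variant kernel that wastes the least padded work
    on an implicit-GEMM view of the problem (a toy heuristic, not the
    dispatcher's real selection logic)."""
    candidates = [k for k in registry
                  if k.direction == direction and k.ndim_spatial == ndim_spatial]
    if not candidates:
        raise LookupError(f"no kernel for {direction}/{ndim_spatial}D")
    def padded_waste(k):
        pad_m = (-m) % k.tile_m  # rows added to reach a tile multiple
        pad_n = (-n) % k.tile_n  # cols added to reach a tile multiple
        return pad_m * n + pad_n * m
    return min(candidates, key=padded_waste)

# A registry would normally be JIT-built; here it is a plain list.
registry = [
    ConvKernelKey("fwd", 2, 256, 256),
    ConvKernelKey("fwd", 2, 128, 128),
    ConvKernelKey("bwd_data", 2, 128, 128),
]
best = select_kernel(registry, "fwd", 2, m=3136, n=64)
print(best.tile_m, best.tile_n)  # -> 128 128
```

In the real dispatcher the selected kernel would then be dispatched via `dispatcher.run(in, wei, out, problem)`; the future heuristics PRs mentioned above would replace a metric like `padded_waste` with tuned per-architecture models.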
1 parent 4c0e73a commit 920acd2

86 files changed

Lines changed: 15501 additions & 1463 deletions


dispatcher/README.md

Lines changed: 104 additions & 37 deletions
@@ -1,6 +1,6 @@
 # CK Tile Dispatcher
 
-A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
+A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations.
 
 **Validated Platform:** AMD Instinct MI300 series (gfx942)
 
@@ -342,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so
 | `CMAKE_PREFIX_PATH` | - | ROCm installation path |
 | `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
 
-⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
-⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
+WARNING: **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
+WARNING: **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
 
 ---

@@ -363,6 +363,15 @@ cd build/examples
 ./gemm_04_heuristics       # Heuristic kernel selection
 ./gemm_05_json_export      # Registry JSON export
 ./gemm_06_multi_registry   # Multiple registries
+
+# Grouped Convolution Examples
+./grouped_conv_01_basic          # Declaration patterns + GPU execution
+./grouped_conv_02_all_dirs       # Forward/BwdData/BwdWeight with GPU
+./grouped_conv_03_bench_val      # Benchmark + CPU reference validation
+./grouped_conv_04_registry_json  # Heuristic selection + JSON export
+./grouped_conv_05_bwd_data       # Backward data + CPU validation
+./grouped_conv_06_bwd_weight     # Backward weight + CPU validation
+./grouped_conv_07_benchmark      # Multi-tile ResNet benchmark
 ```
 
 ### Python Examples
@@ -375,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher
 # GEMM Examples
 python3 examples/gemm/python/01_basic_gemm.py   # Basic multi-kernel GEMM
 python3 examples/gemm/python/04_validation.py   # CPU reference validation
-python3 examples/gemm/python/07_stress_test.py  # Stress test (48 kernels)
+python3 examples/gemm/python/07_stress_test.py  # Stress test
 python3 examples/gemm/python/08_heuristics.py   # Heuristic selection
+
+# Grouped Convolution Examples
+python3 examples/grouped_conv/python/01_basic_grouped_conv.py  # Config patterns + registry + GPU
+python3 examples/grouped_conv/python/02_forward.py             # Forward 2D/3D + CPU ref
+python3 examples/grouped_conv/python/03_bwd_data.py            # Backward data + CPU ref
+python3 examples/grouped_conv/python/04_bwd_weight.py          # Backward weight + CPU ref
+python3 examples/grouped_conv/python/05_benchmark.py           # Multi-problem benchmark
+python3 examples/grouped_conv/python/06_registry_json.py       # Heuristic selection + JSON
 ```
 
 ### Example Output
@@ -647,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
 ### Data Flow
 
 ```
-KernelConfig Registry Dispatcher GPU Execution
+KernelConfig -> Registry -> Dispatcher -> GPU Execution
 ```
 
 1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts)
@@ -843,31 +860,49 @@ make -j$(nproc)
 
 ```
 dispatcher/
-├── README.md                      # This file
-├── CMakeLists.txt                 # Build configuration
-
-├── include/ck_tile/dispatcher/    # C++ headers
-│   ├── dispatcher.hpp             # GEMM dispatcher
-│   ├── registry.hpp               # Kernel registry
-│   └── kernel_key.hpp             # Kernel configuration
-
-├── src/                           # C++ implementation
-
-├── codegen/                       # Kernel generation
-│   ├── unified_gemm_codegen.py    # GEMM kernel generator
-│   └── arch_specs.json            # GPU specifications
-
-├── bindings/ctypes/               # Python ctypes interface
-│   └── gemm_ctypes_lib.cpp        # GEMM Python library
-
-├── examples/                      # Examples
-│   └── gemm/
-│       ├── cpp/                   # C++ GEMM examples (01-06)
-│       └── python/                # Python GEMM examples (01-11)
-
-├── scripts/                       # Build scripts
-
-└── tests/                         # Unit tests
+|---- README.md                              # This file
+|---- CMakeLists.txt                         # Build configuration
+|
+|---- include/ck_tile/dispatcher/            # C++ headers
+|     |---- dispatcher.hpp                   # Main dispatcher include
+|     |---- registry.hpp                     # GEMM kernel registry
+|     |---- kernel_key.hpp                   # Kernel configuration
+|     |---- grouped_conv_config.hpp          # Grouped conv configuration
+|     |---- grouped_conv_problem.hpp         # Grouped conv problem (with builder)
+|     |---- grouped_conv_kernel_decl.hpp     # Grouped conv kernel declarations
+|     |---- grouped_conv_registry.hpp        # Grouped conv registry (thread-safe)
+|     +---- grouped_conv_utils.hpp           # Grouped conv utilities
+|
+|---- src/                                   # C++ implementation
+|
+|---- codegen/                               # Kernel generation
+|     |---- codegen_common.py                # Shared: TileConfig, TraitConfigBase, type mappings
+|     |---- unified_gemm_codegen.py          # GEMM kernel generator
+|     |---- unified_grouped_conv_codegen.py  # Grouped conv kernel generator
+|     +---- arch_specs.json                  # GPU specifications
+|
+|---- python/                                # Python utilities
+|     |---- dispatcher_common.py             # Shared: paths, validation, Colors, phased output
+|     |---- ctypes_utils.py                  # GEMM ctypes utilities
+|     +---- grouped_conv_utils.py            # Grouped conv utilities
+|
+|---- scripts/                               # Build scripts
+|     |---- compile_gemm_examples.py         # GEMM build script
+|     +---- compile_grouped_conv_examples.py # Grouped conv build script
+|
+|---- bindings/ctypes/                       # Python ctypes interface
+|     |---- gemm_ctypes_lib.cpp              # GEMM Python library
+|     +---- conv_ctypes_lib.cpp              # Grouped conv Python library
+|
+|---- examples/                              # Examples
+|     |---- gemm/
+|     |     |---- cpp/                       # C++ GEMM examples (01-07)
+|     |     +---- python/                    # Python GEMM examples (01-11)
+|     +---- grouped_conv/
+|           |---- cpp/                       # C++ Grouped Conv examples (01-07)
+|           +---- python/                    # Python Grouped Conv examples (01-06)
+|
++---- tests/                                 # Unit tests (C++ and Python)
 ```
 
 ---
@@ -879,17 +914,49 @@ dispatcher/
 | GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
 | GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
 | Codegen | [codegen/README.md](codegen/README.md) |
+| Python Utils | [python/README.md](python/README.md) |
+| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) |
 
 ---
 
-## Archived Content
+## Grouped Convolution Support
+
+Grouped convolution is fully supported alongside GEMM, with shared infrastructure to eliminate duplication.
+
+### Python
+
+```bash
+# Generate grouped conv kernels
+python3 codegen/unified_grouped_conv_codegen.py \
+    --output-dir build/generated_kernels \
+    --datatype fp16 --variant forward --ndim-spatial 2
+
+# Build grouped conv examples
+python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
+```
+
+### Key Files
+
+| Component | File |
+|-----------|------|
+| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` |
+| Python Codegen | `codegen/unified_grouped_conv_codegen.py` |
+| Python Utils | `python/grouped_conv_utils.py` |
+| Build Script | `scripts/compile_grouped_conv_examples.py` |
+| Shared Codegen | `codegen/codegen_common.py` |
+| Shared Utils | `python/dispatcher_common.py` |
+
+### Variants
+
+- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution
+- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input
+- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights
+
+### Shared Infrastructure
 
-Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`:
-- `examples/conv/cpp/` - 11 C++ convolution examples
-- `examples/conv/python/` - 14 Python convolution examples
-- `codegen/unified_conv_codegen.py` - Conv kernel generator
-- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers
-- `python/conv_utils.py` - Conv Python utilities
+GEMM and grouped convolution share common code to avoid duplication:
+- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion
+- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output
 
 ---

dispatcher/bindings/README.md

Lines changed: 15 additions & 9 deletions
@@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher.
 
 ```
 bindings/
-├── ctypes/                    # Python ctypes bindings (C API)
-├── gemm_ctypes_lib.cpp        # GEMM dispatcher C API
-├── conv_ctypes_lib.cpp        # Convolution dispatcher C API (fwd + bwd_data)
-├── conv_bwdw_ctypes_lib.cpp   # Convolution backward weight C API
-├── gpu_helper.cpp             # CLI helper for Python
-└── CMakeLists.txt
-└── README.md
+|---- ctypes/                        # Python ctypes bindings (C API)
+|     |---- gemm_ctypes_lib.cpp      # GEMM dispatcher C API
+|     |---- conv_ctypes_lib.cpp      # Grouped conv dispatcher C API (fwd + bwd_data)
+|     |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
+|     |---- gpu_helper.cpp           # CLI helper for Python
+|     +---- CMakeLists.txt
++---- README.md
 ```
 
 ## ctypes Bindings
@@ -65,7 +65,7 @@ lib.dispatcher_cleanup()
 | `dispatcher_export_registry_json()` | Export registry as JSON |
 | `dispatcher_cleanup()` | Release resources |
 
-### Convolution API
+### Grouped Convolution API
 
 | Function | Description |
 |----------|-------------|
@@ -105,5 +105,11 @@ Output is JSON for easy parsing:
 See the examples that use these bindings:
 
 - **GEMM**: `dispatcher/examples/gemm/python/`
-- **Conv**: `dispatcher/examples/conv/python/`
+
+### Grouped Convolution
+
+Grouped convolution C++ headers and Python utilities are in:
+- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp`
+- **Python Utils**: `dispatcher/python/grouped_conv_utils.py`
+- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py`


dispatcher/bindings/ctypes/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ endif()
 # Look for forward kernels
 file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp")
 # Look for backward data kernels
-file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp")
+file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp")
 # Fallback: any conv kernel (for backwards compatibility)
 file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp")

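The glob-with-fallback logic in the hunk above (prefer a `conv_bwd_data_*` header, fall back to any `conv_*` header) can be mimicked in a few lines; the file names here are made up for the demo, and this is an illustration of the selection rule, not part of the build.

```python
import pathlib
import tempfile

def pick_bwdd_header(kernel_dir: pathlib.Path):
    """Prefer a backward-data-specific generated header; otherwise fall
    back to any conv header (mirrors the CMake GLOB fallback)."""
    specific = sorted(kernel_dir.glob("conv_bwd_data_*.hpp"))
    if specific:
        return specific[0]
    generic = sorted(kernel_dir.glob("conv_*.hpp"))
    return generic[0] if generic else None

with tempfile.TemporaryDirectory() as d:
    root = pathlib.Path(d)
    (root / "conv_fwd_fp16.hpp").touch()       # hypothetical generated kernel
    (root / "conv_bwd_data_fp16.hpp").touch()  # hypothetical generated kernel
    print(pick_bwdd_header(root).name)  # -> conv_bwd_data_fp16.hpp
```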
@@ -112,7 +112,7 @@ endif()
 # Add backward data kernel if available
 if(CONV_BWDD_KERNEL_HEADERS)
     list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER)
-    message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
+    message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWD_DATA_KERNEL_HEADER}")
     target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER})
     target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE)
 endif()

dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp

Lines changed: 2 additions & 2 deletions
@@ -53,6 +53,7 @@ struct ConvBwdwProblemC
     int stride_d, stride_h, stride_w;
     int pad_d, pad_h, pad_w;
     int dilation_d, dilation_h, dilation_w;
+    int split_k;
 };
 
 // =============================================================================
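Because `ConvBwdwProblemC` is passed across the ctypes boundary, appending `split_k` means the Python-side `ctypes.Structure` must declare the same fields in the same order. A partial, hypothetical mirror (the real struct has additional leading fields — shapes, channels, etc. — not visible in this hunk, so this is an illustration of the layout rule, not a drop-in binding):

```python
import ctypes

class ConvBwdwProblemC(ctypes.Structure):
    # Partial, hypothetical mirror of the C struct in conv_bwdw_ctypes_lib.cpp.
    # Field order must match the C declaration exactly; new fields like
    # split_k are appended at the end, just as in the C++ diff.
    _fields_ = [
        ("stride_d", ctypes.c_int), ("stride_h", ctypes.c_int), ("stride_w", ctypes.c_int),
        ("pad_d", ctypes.c_int), ("pad_h", ctypes.c_int), ("pad_w", ctypes.c_int),
        ("dilation_d", ctypes.c_int), ("dilation_h", ctypes.c_int), ("dilation_w", ctypes.c_int),
        ("split_k", ctypes.c_int),  # newly appended field
    ]

# Unspecified fields default to 0; split_k=1 disables split-K accumulation.
prob = ConvBwdwProblemC(stride_h=1, stride_w=1, pad_h=1, pad_w=1,
                        dilation_h=1, dilation_w=1, split_k=1)
print(prob.split_k, ctypes.sizeof(ConvBwdwProblemC))
```

Keeping the struct append-only is what lets an older Python caller and a newer library coexist only if the library tolerates a short struct; in practice both sides should be updated together.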
@@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr,
         grad_weight_ptr,   // wei_ptr = grad_weight (output)
         {},                // ds_ptr
         grad_output_ptr,   // out_ptr = grad_output
-        1                  // k_batch
-        );
+        (prob->split_k > 1) ? prob->split_k : 1);
 
     ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
 