Skip to content

Commit eb04107

Browse files
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4 * fix: removed extra parameter from grouped gemm example instance * fix: FP8 check incorrectly enabling FP8 on RDNA3
1 parent 141f77a commit eb04107

44 files changed

Lines changed: 3067 additions & 1223 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

example/15_grouped_gemm/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
4444
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
4545
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
4646

47+
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
48+
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
49+
4750
list(APPEND gpu_list_tf32 gfx942 gfx950)
4851
set(target 0)
4952
foreach(gpu IN LISTS GPU_TARGETS)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
#include <iostream>
5+
#include <numeric>
6+
#include <initializer_list>
7+
#include <cstdlib>
8+
9+
#include "ck/ck.hpp"
10+
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
11+
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
12+
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
13+
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
14+
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
15+
16+
#include <ck/utility/data_type.hpp>
17+
#include <ck/utility/tuple.hpp>
18+
19+
#include "ck/library/utility/check_err.hpp"
20+
#include "ck/library/utility/device_memory.hpp"
21+
#include "ck/library/utility/host_tensor.hpp"
22+
#include "ck/library/utility/host_tensor_generator.hpp"
23+
#include "ck/library/utility/literals.hpp"
24+
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
25+
26+
using ::ck::DeviceMem;
27+
using ::ck::hip_check_error;
28+
using ::ck::HostTensorDescriptor;
29+
using ::ck::Tensor;
30+
31+
template <ck::index_t... Is>
32+
using S = ck::Sequence<Is...>;
33+
34+
using F16 = ck::half_t;
35+
using F32 = float;
36+
37+
using Row = ck::tensor_layout::gemm::RowMajor;
38+
using Col = ck::tensor_layout::gemm::ColumnMajor;
39+
40+
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
41+
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
42+
43+
using ADataType = F16;
44+
using BDataType = F16;
45+
using AccDataType = F32;
46+
using CShuffleDataType = F32;
47+
using DDataType = F16;
48+
using DsDataType = ck::Tuple<DDataType, DDataType>;
49+
using EDataType = F16;
50+
51+
using ALayout = Row;
52+
using BLayout = Col;
53+
using DLayout = Row;
54+
using DsLayout = ck::Tuple<DLayout, DLayout>;
55+
using ELayout = Row;
56+
57+
using AElementOp = PassThrough;
58+
using BElementOp = PassThrough;
59+
using CDEElementOp = AddAdd;
60+
61+
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
62+
static constexpr int NumDs = 2;
63+
64+
using DeviceGemmInstance =
65+
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
66+
// clang-format off
67+
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
68+
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
69+
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
70+
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
71+
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
72+
// clang-format on
73+
74+
#include "run_grouped_gemm_multiple_d_example.inc"
75+
76+
// Program entry point. Forwards the command-line arguments to
// run_grouped_gemm_example() and maps its result onto the process exit
// status: 0 when the result is truthy, 1 otherwise.
int main(int argc, char* argv[])
{
    if(run_grouped_gemm_example(argc, argv))
    {
        return 0;
    }
    return 1;
}

0 commit comments

Comments
 (0)