Skip to content

Commit 30727c4

Browse files
authored
Tile engine for streamk (#3157)
* [CK TILE STREAMK] Introduce initial support for tile engine in streamk GEMM. - This commit lays the groundwork for integrating the tile engine into streamk GEMM. It focuses on creating benchmark executables for streamk GEMM. - Additional scripts like test_benchmark.sh and gemm_benchmark.py will be added once the streamk implementation reaches stability. * [CK TILE STREAMK] Enable CI to execute tile engine benchmarks for StreamK GEMM * [CK TILE STREAMK] Refactor: Extract common utility functions. * [CK TILE STREAMK] Revise tile engine of streamk to align with the updated implementation * Add pre-commit * [CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK * [CK TILE STREAMK] Fix a bug about value of 'dp_persistent' of CK TILE STREAMK * [CK TILE STREAMK] Update Jenkinsfile * [CK TILE Engine] Update StreamK tile engine help message Remove default value messages as they are automatically printed * [CK TILE Engine] Update StreamK tile engine - Remove namespace reboot * [CK TILE Engine] Update StreamK tile engine - Fix merge error
1 parent 24d88d2 commit 30727c4

15 files changed

Lines changed: 2530 additions & 19 deletions

Jenkinsfile

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,11 +1615,13 @@ pipeline {
16151615
-D GPU_TARGETS="gfx90a" \
16161616
-D GEMM_DATATYPE="fp8;fp16" \
16171617
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
1618+
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
1619+
-D GEMM_STREAMK_LAYOUT="rcr" \
16181620
-D GEMM_MULTI_D_DATATYPE="fp16" \
16191621
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
16201622
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
16211623
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
1622-
ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
1624+
ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
16231625
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
16241626
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
16251627
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """
@@ -1644,11 +1646,13 @@ pipeline {
16441646
-D GPU_TARGETS="gfx942" \
16451647
-D GEMM_DATATYPE="fp8;fp16" \
16461648
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
1649+
-D GEMM_STREAMK_DATATYPE="fp8;fp16" \
1650+
-D GEMM_STREAMK_LAYOUT="rcr" \
16471651
-D GEMM_MULTI_D_DATATYPE="fp16" \
16481652
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
16491653
-D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \
16501654
-D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \
1651-
ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \
1655+
ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \
16521656
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
16531657
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \
16541658
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """

example/ck_tile/40_streamk_gemm/run_gemm_example.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
8686

8787
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
8888

89-
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
89+
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
9090
{
9191
ave_time_and_batch = gemm<GemmConfig,
9292
ADataType,

example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
105105
}
106106

107107
auto reset_data_buffers = [&]() {
108-
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
108+
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
109109
{
110110
// Clear the output C tensor results after each repetition of the kernel
111111
hipGetErrorString(hipMemsetAsync(
112112
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
113113
}
114-
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
114+
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
115115
{
116116
// Reset sk flags to zero before each repetition of the kernel
117117
workspace_data.SetZero();

include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
2828
index_t K_,
2929
index_t stride_A_,
3030
index_t stride_B_,
31-
index_t stride_C_,
32-
StreamKReductionStrategy reduction_strategy_)
31+
index_t stride_C_)
3332
: UniversalGemmHostArgs<>({a_ptr_},
3433
{b_ptr_},
3534
{/*ds_ptr*/},
@@ -41,12 +40,9 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
4140
{stride_A_},
4241
{stride_B_},
4342
{/*stride_Ds_*/},
44-
stride_C_),
45-
reduction_strategy{reduction_strategy_}
43+
stride_C_)
4644
{
4745
}
48-
49-
ck_tile::StreamKReductionStrategy reduction_strategy;
5046
};
5147

5248
/**
@@ -133,18 +129,13 @@ struct StreamKKernel
133129
host_args.stride_Ds,
134130
host_args.stride_E,
135131
host_args.k_batch},
136-
reduction_strategy{host_args.reduction_strategy},
137132
// The workspace pointer is set to nullptr because we must first
138133
// instantiate the TilePartitioner to get the necessary size
139134
workspace_ptr{nullptr},
140135
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
141136

142137
{
143138
}
144-
/**
145-
* @brief The strategy used by work groups to compute final results in C tensor.
146-
*/
147-
StreamKReductionStrategy reduction_strategy;
148139
/**
149140
* @brief A pointer to a buffer in device memory for accumulating partial via reduction
150141
* strategy.

test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,7 @@ class TestCkTileStreamK : public ::testing::Test
250250
K,
251251
stride_A,
252252
stride_B,
253-
stride_C,
254-
reduction_strategy};
253+
stride_C};
255254

256255
ck_tile::index_t num_accumulations_per_tile =
257256
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
3+
4+
#pragma once
5+
6+
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
7+
auto calculate_rtol_atol(const ck_tile::index_t K,
8+
const ck_tile::index_t kbatch,
9+
const float max_accumulated_value)
10+
{
11+
using ComputeType =
12+
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
13+
// Calculate thresholds
14+
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
15+
ck_tile::integer_divide_ceil(K, kbatch));
16+
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
17+
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
18+
// Calculate error due to split_k accumulation
19+
const auto rtol_split_k =
20+
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
21+
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
22+
max_accumulated_value, kbatch);
23+
// Use higher threshold
24+
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
25+
}
26+
27+
/// @brief Function to compare the results of the device and host computations
28+
bool compare(std::string instanceName,
29+
ck_tile::index_t K,
30+
ck_tile::index_t kbatch,
31+
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
32+
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
33+
{
34+
const float max_accumulated_value =
35+
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
36+
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
37+
K, kbatch, max_accumulated_value);
38+
bool pass = ck_tile::check_err(c_m_n_dev_result,
39+
c_m_n_host_result,
40+
"Error: Incorrect results!",
41+
rtol_atol.at(ck_tile::number<0>{}),
42+
rtol_atol.at(ck_tile::number<1>{}));
43+
44+
std::cout << "For " << instanceName << " Relative error threshold is "
45+
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
46+
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
47+
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
48+
49+
return pass;
50+
}

tile_engine/ops/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(gemm)
22
add_subdirectory(gemm_multi_d)
3-
add_subdirectory(gemm_preshuffle)
3+
add_subdirectory(gemm_preshuffle)
4+
add_subdirectory(gemm_streamk)

0 commit comments

Comments
 (0)