Skip to content

Commit 8c4cb4f

Browse files
Jimmyniu9797mkumar16-amdThomasNing
authored
Jimniu/ ck tile gemm stride validation (#2710)
* Add stride validation for gemm_basic * change default stride statement * Fix build error * Fix pre-commit failure * Addressed PR comments * clear the redundant code * clang format --------- Co-authored-by: mkumar16-amd <mkumar16@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
1 parent 1e77695 commit 8c4cb4f

4 files changed

Lines changed: 77 additions & 0 deletions

File tree

example/ck_tile/03_gemm/gemm_basic.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,32 @@
55
#include "run_gemm_example.inc"
66
#include "run_gemm_example_common.hpp"
77
#include "gemm_basic_invoker.hpp"
8+
#include "ck_tile/core/utility/gemm_validation.hpp"
89

910
int run_gemm_example(ck_tile::ArgParser& arg_parser)
1011
{
1112
std::string data_type = arg_parser.get_str("prec");
1213
std::string a_layout = arg_parser.get_str("a_layout");
1314
std::string b_layout = arg_parser.get_str("b_layout");
15+
std::string c_layout = arg_parser.get_str("c_layout");
16+
17+
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> gemm_sizes =
18+
parse_gemm_size(arg_parser);
19+
20+
int m = std::get<0>(gemm_sizes);
21+
int n = std::get<1>(gemm_sizes);
22+
int k = std::get<2>(gemm_sizes);
23+
24+
int stride_a = arg_parser.get_int("stride_a");
25+
int stride_b = arg_parser.get_int("stride_b");
26+
int stride_c = arg_parser.get_int("stride_c");
1427

1528
using GemmConfig = GemmConfigBase;
1629
using Invoker = BasicInvoker;
1730

31+
ck_tile::validate_gemm_stride(
32+
a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c);
33+
1834
if(data_type == "fp16")
1935
{
2036
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::half_t>(

example/ck_tile/03_gemm/run_gemm_example.inc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,15 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
254254
return pass;
255255
}
256256

257+
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
258+
parse_gemm_size(ck_tile::ArgParser& arg_parser)
259+
{
260+
ck_tile::index_t M = arg_parser.get_int("m");
261+
ck_tile::index_t N = arg_parser.get_int("n");
262+
ck_tile::index_t K = arg_parser.get_int("k");
263+
return std::make_tuple(M, N, K);
264+
}
265+
257266
template <typename GemmConfig,
258267
typename Invoker,
259268
typename ADataType,

include/ck_tile/core.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
#include "ck_tile/core/utility/env.hpp"
7272
#include "ck_tile/core/utility/functional.hpp"
7373
#include "ck_tile/core/utility/functional_with_tuple.hpp"
74+
#include "ck_tile/core/utility/gemm_validation.hpp"
7475
#include "ck_tile/core/utility/ignore.hpp"
7576
#include "ck_tile/core/utility/literals.hpp"
7677
#include "ck_tile/core/utility/magic_div.hpp"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <stdexcept>
8+
#include "ck_tile/core/config.hpp"
9+
10+
namespace ck_tile {
11+
12+
inline void
13+
validate_stride(std::string Layout, int M, int N, int stride, const std::string& stride_name)
14+
{
15+
if(Layout == "C" && stride < M)
16+
{
17+
throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
18+
std::to_string(stride) + ") must be greater or equal to dim " +
19+
std::to_string(M));
20+
}
21+
if(Layout == "R" && stride < N)
22+
{
23+
throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
24+
std::to_string(stride) + ") must be greater or equal to dim " +
25+
std::to_string(N));
26+
}
27+
}
28+
29+
inline void validate_gemm_stride(std::string a_layout,
30+
std::string b_layout,
31+
std::string c_layout,
32+
int M,
33+
int N,
34+
int K,
35+
int Stride_A,
36+
int Stride_B,
37+
int Stride_C)
38+
{
39+
// set default stride
40+
if(Stride_A <= 0)
41+
Stride_A = (a_layout == "R") ? K : M;
42+
if(Stride_B <= 0)
43+
Stride_B = (b_layout == "R") ? N : K;
44+
if(Stride_C <= 0)
45+
Stride_C = (c_layout == "R") ? N : M;
46+
47+
validate_stride(a_layout, M, K, Stride_A, "Stride_A");
48+
validate_stride(b_layout, K, N, Stride_B, "Stride_B");
49+
validate_stride(c_layout, M, N, Stride_C, "Stride_C");
50+
}
51+
} // namespace ck_tile

0 commit comments

Comments
 (0)