Skip to content

Commit 24d88d2

Browse files
arai713cgmillette
andauthored
[CK_TILE] Move DataTypeTraits into a Common File (#3146)
This renames the typeToStr struct in the common utilities to DataTypeTraits and removes all duplication of DataTypeTraits across files in CK Tile. Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
1 parent 678298d commit 24d88d2

17 files changed

Lines changed: 93 additions & 473 deletions

File tree

example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <tuple>
1111

1212
#include "ck_tile/host.hpp"
13+
#include "ck_tile/ops/common/utils.hpp"
1314
#include "ck_tile/ops/reduce.hpp"
1415
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
1516
#include "gemm_utils.hpp"
@@ -589,9 +590,10 @@ float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf,
589590
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
590591
<< " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes"
591592
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
592-
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
593-
<< " B_Type=" << DataTypeTraits<BDataType>::name
594-
<< " C_Type=" << DataTypeTraits<CDataType>::name
593+
<< " C_Layout=" << CLayout::name
594+
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
595+
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
596+
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
595597
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
596598
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
597599
<< tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -401,63 +401,6 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
401401
using CDataType = int32_t;
402402
};
403403

404-
template <typename T>
405-
struct DataTypeTraits;
406-
407-
template <>
408-
struct DataTypeTraits<float>
409-
{
410-
static constexpr const char* name = "fp32";
411-
};
412-
413-
template <>
414-
struct DataTypeTraits<double>
415-
{
416-
static constexpr const char* name = "fp64";
417-
};
418-
419-
template <>
420-
struct DataTypeTraits<int32_t>
421-
{
422-
static constexpr const char* name = "int32";
423-
};
424-
425-
template <>
426-
struct DataTypeTraits<ck_tile::half_t>
427-
{
428-
static constexpr const char* name = "fp16";
429-
};
430-
431-
template <>
432-
struct DataTypeTraits<ck_tile::bf16_t>
433-
{
434-
static constexpr const char* name = "bf16";
435-
};
436-
437-
template <>
438-
struct DataTypeTraits<ck_tile::fp8_t>
439-
{
440-
static constexpr const char* name = "fp8";
441-
};
442-
443-
template <>
444-
struct DataTypeTraits<ck_tile::bf8_t>
445-
{
446-
static constexpr const char* name = "bf8";
447-
};
448-
449-
template <>
450-
struct DataTypeTraits<ck_tile::pk_int4_t>
451-
{
452-
static constexpr const char* name = "pk_int4_t";
453-
};
454-
455-
template <>
456-
struct DataTypeTraits<ck_tile::int8_t>
457-
{
458-
static constexpr const char* name = "int8";
459-
};
460-
461404
template <ck_tile::GemmPipeline PipelineId>
462405
struct PipelineTypeTraits;
463406

example/ck_tile/03_gemm/run_gemm_example.inc

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55
#include "ck_tile/host/permute_pk_int4.hpp"
66
#include "ck_tile/host/tensor_shuffle_utils.hpp"
7+
#include "ck_tile/ops/common/utils.hpp"
78

89
template <typename Layout>
910
static constexpr inline auto is_row_major(Layout layout_)
@@ -372,9 +373,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
372373
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
373374
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
374375
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
375-
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
376-
<< " B_Type=" << DataTypeTraits<BDataType>::name
377-
<< " C_Type=" << DataTypeTraits<CDataType>::name
376+
<< " C_Layout=" << CLayout::name
377+
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
378+
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
379+
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
378380
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
379381
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
380382
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -442,18 +444,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
442444
BDataType,
443445
CDataType,
444446
GemmConfig,
445-
DataTypeTraits>(arg_parser.get_str("jsonfile"),
446-
M,
447-
N,
448-
K,
449-
stride_A,
450-
stride_B,
451-
stride_C,
452-
persistent,
453-
pass,
454-
ave_time,
455-
tflops,
456-
gb_per_sec);
447+
ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"),
448+
M,
449+
N,
450+
K,
451+
stride_A,
452+
stride_B,
453+
stride_C,
454+
persistent,
455+
pass,
456+
ave_time,
457+
tflops,
458+
gb_per_sec);
457459
}
458460

459461
return pass;

example/ck_tile/05_reduce/reduce.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,6 @@
66
#include "ck_tile/utility/json_dump.hpp"
77
#include <cstring>
88

9-
template <typename T>
10-
struct DataTypeTraits;
11-
12-
template <>
13-
struct DataTypeTraits<ck_tile::half_t>
14-
{
15-
static constexpr const char* name = "fp16";
16-
};
17-
18-
template <>
19-
struct DataTypeTraits<ck_tile::bf16_t>
20-
{
21-
static constexpr const char* name = "bf16";
22-
};
23-
249
auto create_args(int argc, char* argv[])
2510
{
2611
ck_tile::ArgParser arg_parser;
@@ -145,7 +130,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
145130

146131
if(arg_parser.get_int("json") == 1)
147132
{
148-
dump_reduce_json_results<DataType, DataTypeTraits>(
133+
dump_reduce_json_results<DataType, ck_tile::DataTypeTraits>(
149134
arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec);
150135
}
151136

example/ck_tile/18_flatmm/flatmm_basic.hpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -136,38 +136,6 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
136136
using CDataType = ck_tile::half_t;
137137
};
138138

139-
template <typename T>
140-
struct DataTypeTraits;
141-
142-
template <>
143-
struct DataTypeTraits<ck_tile::fp8_t>
144-
{
145-
static constexpr const char* name = "fp8";
146-
};
147-
148-
template <>
149-
struct DataTypeTraits<ck_tile::bf8_t>
150-
{
151-
static constexpr const char* name = "bf8";
152-
};
153-
template <>
154-
struct DataTypeTraits<float>
155-
{
156-
static constexpr const char* name = "fp32";
157-
};
158-
159-
template <>
160-
struct DataTypeTraits<double>
161-
{
162-
static constexpr const char* name = "fp64";
163-
};
164-
165-
template <>
166-
struct DataTypeTraits<ck_tile::half_t>
167-
{
168-
static constexpr const char* name = "fp16";
169-
};
170-
171139
template <typename T>
172140
struct is_8bit_type
173141
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>

example/ck_tile/18_flatmm/moe_flatmm.hpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -134,38 +134,6 @@ struct GemmBasicTypeConfig<ck_tile::bf8_t>
134134
using CDataType = ck_tile::half_t;
135135
};
136136

137-
template <typename T>
138-
struct DataTypeTraits;
139-
140-
template <>
141-
struct DataTypeTraits<ck_tile::fp8_t>
142-
{
143-
static constexpr const char* name = "fp8";
144-
};
145-
146-
template <>
147-
struct DataTypeTraits<ck_tile::bf8_t>
148-
{
149-
static constexpr const char* name = "bf8";
150-
};
151-
template <>
152-
struct DataTypeTraits<float>
153-
{
154-
static constexpr const char* name = "fp32";
155-
};
156-
157-
template <>
158-
struct DataTypeTraits<double>
159-
{
160-
static constexpr const char* name = "fp64";
161-
};
162-
163-
template <>
164-
struct DataTypeTraits<ck_tile::half_t>
165-
{
166-
static constexpr const char* name = "fp16";
167-
};
168-
169137
template <typename T>
170138
struct is_8bit_type
171139
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>

example/ck_tile/20_grouped_convolution/conv_configs.hpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -254,27 +254,6 @@ struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
254254
using OutDataType = ck_tile::bf16_t;
255255
};
256256

257-
template <typename T>
258-
struct DataTypeTraits;
259-
260-
template <>
261-
struct DataTypeTraits<float>
262-
{
263-
static constexpr const char* name = "fp32";
264-
};
265-
266-
template <>
267-
struct DataTypeTraits<ck_tile::half_t>
268-
{
269-
static constexpr const char* name = "fp16";
270-
};
271-
272-
template <>
273-
struct DataTypeTraits<ck_tile::bf16_t>
274-
{
275-
static constexpr const char* name = "bf16";
276-
};
277-
278257
template <ck_tile::GemmPipeline PipelineId>
279258
struct PipelineTypeTraits;
280259

example/ck_tile/38_block_scale_gemm/gemm_utils.hpp

Lines changed: 32 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -281,59 +281,36 @@ struct GemmQuantTypeConfig
281281
using CDataType = CDataType_;
282282
};
283283

284-
template <typename T>
285-
struct DataTypeTraits;
286-
287-
template <>
288-
struct DataTypeTraits<float>
289-
{
290-
static constexpr const char* name = "fp32";
291-
};
292-
293-
template <>
294-
struct DataTypeTraits<double>
295-
{
296-
static constexpr const char* name = "fp64";
297-
};
298-
299-
template <>
300-
struct DataTypeTraits<int32_t>
284+
auto create_args(int argc, char* argv[])
301285
{
302-
static constexpr const char* name = "int32";
303-
};
304-
305-
template <>
306-
struct DataTypeTraits<ck_tile::half_t>
307-
{
308-
static constexpr const char* name = "fp16";
309-
};
310-
311-
template <>
312-
struct DataTypeTraits<ck_tile::bf16_t>
313-
{
314-
static constexpr const char* name = "bf16";
315-
};
316-
317-
template <>
318-
struct DataTypeTraits<ck_tile::fp8_t>
319-
{
320-
static constexpr const char* name = "fp8";
321-
};
322-
323-
template <>
324-
struct DataTypeTraits<ck_tile::bf8_t>
325-
{
326-
static constexpr const char* name = "bf8";
327-
};
328-
329-
template <>
330-
struct DataTypeTraits<ck_tile::pk_int4_t>
331-
{
332-
static constexpr const char* name = "pk_int4_t";
333-
};
334-
335-
template <>
336-
struct DataTypeTraits<ck_tile::int8_t>
337-
{
338-
static constexpr const char* name = "int8";
339-
};
286+
ck_tile::ArgParser arg_parser;
287+
arg_parser.insert("m", "3840", "m dimension")
288+
.insert("n", "4096", "n dimension")
289+
.insert("k", "2048", "k dimension")
290+
.insert("a_layout", "R", "A tensor data layout - Row by default")
291+
.insert("b_layout", "C", "B tensor data layout - Column by default")
292+
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
293+
.insert("c_layout", "R", "C tensor data layout - Row by default")
294+
.insert("stride_a", "0", "Tensor A stride")
295+
.insert("stride_q", "0", "Tensor AQ stride")
296+
.insert("stride_b", "0", "Tensor B stride")
297+
.insert("stride_c", "0", "Tensor C stride")
298+
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
299+
.insert("prec",
300+
"fp8",
301+
"data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4")
302+
.insert("warmup", "50", "number of iterations before benchmark the kernel")
303+
.insert("repeat", "1000", "number of iterations to benchmark the kernel")
304+
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
305+
.insert("split_k", "1", "splitK value")
306+
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
307+
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
308+
.insert("rotating_count", "1000", "rotating count, defaults to 1")
309+
.insert("quant_mode", "bquant", "Choose aquant (default), bquant, tensor or rowcol")
310+
.insert("group_size",
311+
"1x1x128",
312+
"Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128");
313+
314+
bool result = arg_parser.parse(argc, argv);
315+
return std::make_tuple(result, arg_parser);
316+
}

example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <tuple>
1212

1313
#include "ck_tile/core/config.hpp"
14+
#include "ck_tile/ops/common/utils.hpp"
1415
#include "ck_tile/host.hpp"
1516
#include "ck_tile/host/permute_pk_int4.hpp"
1617
#include "ck_tile/host/tensor_shuffle_utils.hpp"
@@ -321,15 +322,15 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
321322
{
322323
std::cout << " StrideBQ =" << stride_BQ;
323324
}
324-
std::cout << " A_Type = " << DataTypeTraits<typename TypeConfig::ADataType>::name
325-
<< " AQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name
326-
<< " B_Type = " << DataTypeTraits<typename TypeConfig::BDataType>::name;
325+
std::cout << " A_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::ADataType>::name
326+
<< " AQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name
327+
<< " B_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::BDataType>::name;
327328
if constexpr(!std::is_same_v<typename TypeConfig::QDataType, void>)
328329
{
329-
std::cout << " BQ_Type = " << DataTypeTraits<typename TypeConfig::QDataType>::name;
330+
std::cout << " BQ_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::QDataType>::name;
330331
}
331-
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
332-
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
332+
std::cout << " Acc_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::AccDataType>::name
333+
<< " C_Type = " << ck_tile::DataTypeTraits<typename TypeConfig::CDataType>::name
333334
<< " QuantMode = " << quant_type_to_string(QuantMode)
334335
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
335336
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "

0 commit comments

Comments
 (0)