Skip to content

Commit f17e37c

Browse files
authored
feat: add high-level DispatchFunc() interface for multi-type and mixed dispatch (#29)
* feat: add a convenient interface for any `int64_t`-convertible types and use `DispatchFunc()` to dispatch `DataType` and block sizes with a single call.
  - add a convenient interface for any `int64_t`-convertible types, which is mostly used for multi-type dispatch and mixed dispatch
  - use `DispatchFunc()` to dispatch `DataType` and block sizes with a single function call in various kernels' implementations
  - remove the `CUDA_BLOCK_SIZE_XXX` macros and simply use numbers instead
* style: fix the styling issue by adding a period to the TODO comment
* fix: fix rebase error
* style: fix the styling issues for comments in `dispatcher.h` and `cuda/causal_softmax/kernel.h`
1 parent 8c92b2e commit f17e37c

File tree

7 files changed

+74
-116
lines changed

7 files changed

+74
-116
lines changed

src/cuda/add/kernel.h

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ class CudaAdd : public Add {
5151
void operator()(const Tensor input, const Tensor other,
5252
Tensor out) const override {
5353
int block_size = GetOptimalBlockSize();
54-
DispatchFunc<AllTypes>(
55-
out_type_,
56-
[&](auto tag) {
57-
using T = typename decltype(tag)::type;
54+
DispatchFunc<AllTypes, AllCudaBlockSizes>(
55+
{static_cast<int64_t>(out_type_), block_size},
56+
[&](auto list_tag) {
57+
using T = TypeMapType<ListGet<0>(list_tag)>;
58+
constexpr int kBlockSize = ListGet<1>(list_tag);
59+
5860
auto cuda_stream =
5961
static_cast<typename Backend::stream_t>(stream_ ? stream_ : 0);
6062
dim3 blockDims(
@@ -65,25 +67,11 @@ class CudaAdd : public Add {
6567
const T* d_input = reinterpret_cast<const T*>(input.data());
6668
const T* d_other = reinterpret_cast<const T*>(other.data());
6769

68-
#define LAUNCH_ADD_KERNEL(BLOCK_SIZE) \
69-
AddKernel<T, BLOCK_SIZE><<<gridDims, blockDims, 0, cuda_stream>>>( \
70-
d_out, d_input, d_other, d_out_shape_, d_input_shape_, d_other_shape_, \
71-
d_out_strides_, d_input_strides_, d_other_strides_, output_size_, ndim_, \
72-
is_out_contiguous_, is_input_contiguous_, is_other_contiguous_);
73-
74-
if (block_size == CUDA_BLOCK_SIZE_2048) {
75-
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_2048)
76-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
77-
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_1024)
78-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
79-
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_512)
80-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
81-
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_256)
82-
} else {
83-
LAUNCH_ADD_KERNEL(CUDA_BLOCK_SIZE_128)
84-
}
85-
86-
#undef LAUNCH_ADD_KERNEL
70+
AddKernel<T, kBlockSize><<<gridDims, blockDims, 0, cuda_stream>>>(
71+
d_out, d_input, d_other, d_out_shape_, d_input_shape_,
72+
d_other_shape_, d_out_strides_, d_input_strides_,
73+
d_other_strides_, output_size_, ndim_, is_out_contiguous_,
74+
is_input_contiguous_, is_other_contiguous_);
8775
},
8876
"CudaAdd::operator()");
8977
}

src/cuda/causal_softmax/kernel.h

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,20 @@ class CudaCausalSoftmax : public CausalSoftmax {
3434

3535
int block_size = GetOptimalBlockSize();
3636

37-
DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
38-
out.dtype(),
39-
[&](auto tag) {
40-
using T = typename decltype(tag)::type;
41-
42-
#define LAUNCH_CAUSAL_SOFTMAX_KERNEL(BLOCK_SIZE) \
43-
CausalSoftmaxKernel<BLOCK_SIZE, T, float> \
44-
<<<grid, BLOCK_SIZE, 0, cuda_stream>>>( \
45-
reinterpret_cast<T*>(out.data()), \
46-
reinterpret_cast<const T*>(input.data()), batch_size_, seq_len_, \
47-
total_seq_len_, stride_out_batch, stride_out_row, \
48-
stride_input_batch, stride_input_row);
49-
50-
if (block_size == CUDA_BLOCK_SIZE_2048) {
51-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_2048)
52-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
53-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_1024)
54-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
55-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_512)
56-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
57-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_256)
58-
} else {
59-
LAUNCH_CAUSAL_SOFTMAX_KERNEL(CUDA_BLOCK_SIZE_128)
60-
}
61-
62-
#undef LAUNCH_CAUSAL_SOFTMAX_KERNEL
37+
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
38+
AllCudaBlockSizes>(
39+
// TODO: Output dtype should use the one passed in during construction.
40+
{static_cast<int64_t>(out.dtype()), block_size},
41+
[&](auto list_tag) {
42+
using T = TypeMapType<ListGet<0>(list_tag)>;
43+
constexpr int kBlockSize = ListGet<1>(list_tag);
44+
45+
CausalSoftmaxKernel<kBlockSize, T, float>
46+
<<<grid, kBlockSize, 0, cuda_stream>>>(
47+
reinterpret_cast<T*>(out.data()),
48+
reinterpret_cast<const T*>(input.data()), batch_size_,
49+
seq_len_, total_seq_len_, stride_out_batch, stride_out_row,
50+
stride_input_batch, stride_input_row);
6351
},
6452
"CudaCausalSoftmax::operator()");
6553
}

src/cuda/kernel_commons.h

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,7 @@ using cuda_bfloat162 = __mt_bfloat162;
3333

3434
namespace infini::ops {
3535

36-
constexpr int CUDA_BLOCK_SIZE_128 = 128;
37-
constexpr int CUDA_BLOCK_SIZE_256 = 256;
38-
constexpr int CUDA_BLOCK_SIZE_512 = 512;
39-
constexpr int CUDA_BLOCK_SIZE_1024 = 1024;
40-
constexpr int CUDA_BLOCK_SIZE_2048 = 2048;
36+
using AllCudaBlockSizes = List<128, 256, 512, 1024, 2048>;
4137

4238
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
4339
// Cache `cudaDeviceProp` per device, initialized once at first access.
@@ -76,7 +72,7 @@ inline int QueryMaxThreadsPerBlock() {
7672
#elif defined(WITH_METAX)
7773
inline int QueryMaxThreadsPerBlock() {
7874
// TODO: Add MCR device properties query for Metax.
79-
return CUDA_BLOCK_SIZE_256;
75+
return 256;
8076
}
8177
#elif defined(WITH_MOORE)
8278
inline int QueryMaxThreadsPerBlock() {
@@ -91,16 +87,16 @@ inline int QueryMaxThreadsPerBlock() {
9187
// Get optimal block size based on GPU hardware architecture.
9288
inline int GetOptimalBlockSize() {
9389
int max_threads = QueryMaxThreadsPerBlock();
94-
if (max_threads >= CUDA_BLOCK_SIZE_2048) {
95-
return CUDA_BLOCK_SIZE_2048;
96-
} else if (max_threads >= CUDA_BLOCK_SIZE_1024) {
97-
return CUDA_BLOCK_SIZE_1024;
98-
} else if (max_threads >= CUDA_BLOCK_SIZE_512) {
99-
return CUDA_BLOCK_SIZE_512;
100-
} else if (max_threads >= CUDA_BLOCK_SIZE_256) {
101-
return CUDA_BLOCK_SIZE_256;
90+
if (max_threads >= 2048) {
91+
return 2048;
92+
} else if (max_threads >= 1024) {
93+
return 1024;
94+
} else if (max_threads >= 512) {
95+
return 512;
96+
} else if (max_threads >= 256) {
97+
return 256;
10298
} else {
103-
return CUDA_BLOCK_SIZE_128;
99+
return 128;
104100
}
105101
}
106102

src/cuda/rms_norm/kernel.h

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,20 @@ class CudaRmsNorm : public RmsNorm {
3636

3737
int block_size = GetOptimalBlockSize();
3838

39-
DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
40-
out.dtype(),
41-
[&](auto tag) {
42-
using T = typename decltype(tag)::type;
43-
44-
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
45-
RmsNormKernel<BLOCK_SIZE, float, T, T> \
46-
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
47-
reinterpret_cast<T*>(out.data()), stride_out_batch, \
48-
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
49-
stride_input_batch, stride_input_nhead, \
50-
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
51-
52-
if (block_size == CUDA_BLOCK_SIZE_2048) {
53-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
54-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
55-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024)
56-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
57-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512)
58-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
59-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256)
60-
} else {
61-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128)
62-
}
63-
64-
#undef LAUNCH_RMS_NORM_KERNEL
39+
DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
40+
AllCudaBlockSizes>(
41+
{static_cast<int64_t>(out.dtype()), block_size},
42+
[&](auto list_tag) {
43+
using T = TypeMapType<ListGet<0>(list_tag)>;
44+
constexpr int kBlockSize = ListGet<1>(list_tag);
45+
46+
RmsNormKernel<kBlockSize, float, T, T>
47+
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
48+
reinterpret_cast<T*>(out.data()), stride_out_batch,
49+
stride_out_nhead, reinterpret_cast<const T*>(input.data()),
50+
stride_input_batch, stride_input_nhead,
51+
reinterpret_cast<const T*>(weight.data()), nhead_, dim_,
52+
eps_);
6553
},
6654
"CudaRmsNorm::operator()");
6755
}

src/cuda/swiglu/kernel.h

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ class CudaSwiglu : public Swiglu {
5050
void operator()(const Tensor input, const Tensor gate,
5151
Tensor out) const override {
5252
int block_size = GetOptimalBlockSize();
53-
DispatchFunc<AllFloatTypes>(
54-
out_type_,
55-
[&](auto tag) {
56-
using T = typename decltype(tag)::type;
53+
DispatchFunc<AllFloatTypes, AllCudaBlockSizes>(
54+
{static_cast<int64_t>(out_type_), block_size},
55+
[&](auto list_tag) {
56+
using T = TypeMapType<ListGet<0>(list_tag)>;
57+
constexpr int kBlockSize = ListGet<1>(list_tag);
58+
5759
auto cuda_stream =
5860
static_cast<typename Backend::stream_t>(stream_ ? stream_ : 0);
5961
dim3 blockDims(
@@ -64,25 +66,11 @@ class CudaSwiglu : public Swiglu {
6466
const T* d_input = reinterpret_cast<const T*>(input.data());
6567
const T* d_gate = reinterpret_cast<const T*>(gate.data());
6668

67-
// Launch kernel with appropriate block size based on GPU architecture.
68-
#define LAUNCH_SWIGLU_KERNEL(BLOCK_SIZE) \
69-
SwigluKernel<T, BLOCK_SIZE><<<gridDims, blockDims, 0, cuda_stream>>>( \
70-
d_out, d_input, d_gate, d_out_shape_, d_input_shape_, d_gate_shape_, \
71-
d_out_strides_, d_input_strides_, d_gate_strides_, output_size_, ndim_, \
72-
is_out_contiguous_, is_input_contiguous_, is_gate_contiguous_);
73-
if (block_size == CUDA_BLOCK_SIZE_2048) {
74-
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_2048)
75-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
76-
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_1024)
77-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
78-
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_512)
79-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
80-
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_256)
81-
} else {
82-
LAUNCH_SWIGLU_KERNEL(CUDA_BLOCK_SIZE_128)
83-
}
84-
85-
#undef LAUNCH_SWIGLU_KERNEL
69+
SwigluKernel<T, kBlockSize><<<gridDims, blockDims, 0, cuda_stream>>>(
70+
d_out, d_input, d_gate, d_out_shape_, d_input_shape_,
71+
d_gate_shape_, d_out_strides_, d_input_strides_, d_gate_strides_,
72+
output_size_, ndim_, is_out_contiguous_, is_input_contiguous_,
73+
is_gate_contiguous_);
8674
},
8775
"CudaSwiglu::operator()");
8876
}

src/dispatcher.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,16 @@ auto DispatchFunc(ValueType value, Functor &&func,
302302
std::forward<Args>(args)...);
303303
}
304304

305+
// Interface for Any `int64_t`-Convertible Types
306+
template <typename... Lists, typename Functor, typename... Args>
307+
auto DispatchFunc(std::initializer_list<int64_t> keys, Functor &&func,
308+
std::string_view context_str = "", Args &&...args) {
309+
std::vector<int64_t> v_keys(keys);
310+
return DispatchFunc<Lists...>(v_keys, 0, std::forward<Functor>(func),
311+
context_str, List<>{},
312+
std::forward<Args>(args)...);
313+
}
314+
305315
} // namespace infini::ops
306316

307317
#endif

src/operator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ class Operator : public OperatorBase {
103103
DispatchFunc<ActiveDevices>(
104104
tensor.device().type(),
105105
[&](auto tag) {
106-
constexpr Device::Type dev = decltype(tag)::value;
107-
if constexpr (std::is_constructible_v<Operator<Key, dev>,
106+
constexpr Device::Type kDev = decltype(tag)::value;
107+
if constexpr (std::is_constructible_v<Operator<Key, kDev>,
108108
const Tensor&, Args...>) {
109-
op_ptr = std::make_unique<Operator<Key, dev>>(
109+
op_ptr = std::make_unique<Operator<Key, kDev>>(
110110
tensor, std::forward<Args>(args)...);
111111
} else {
112112
assert(false && "operator is not implemented for this device");

0 commit comments

Comments
 (0)