@@ -43,36 +43,13 @@ class CudaRmsNorm : public RmsNorm {
4343 using T = TypeMapType<ListGet<0>(list_tag)>;
4444 constexpr int kBlockSize = ListGet<1>(list_tag);
4545
46- <<<<<<< HEAD
47- #define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
48- RmsNormKernel<BLOCK_SIZE, float, T, T> \
49- <<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
50- reinterpret_cast<T*>(out.data()), stride_out_batch, \
51- stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
52- stride_input_batch, stride_input_nhead, \
53- reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
54-
55- if (block_size == CUDA_BLOCK_SIZE_2048) {
56- LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
57- } else if (block_size == CUDA_BLOCK_SIZE_1024) {
58- LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024)
59- } else if (block_size == CUDA_BLOCK_SIZE_512) {
60- LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512)
61- } else if (block_size == CUDA_BLOCK_SIZE_256) {
62- LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256)
63- } else {
64- LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128)
65- }
66-
67- #undef LAUNCH_RMS_NORM_KERNEL
68- ======= RmsNormKernel<kBlockSize, float, T, T>
46+ RmsNormKernel<kBlockSize, float, T, T>
6947 <<<num_blocks, kBlockSize, 0, cuda_stream>>>(
7048 reinterpret_cast<T*>(out.data()), stride_out_batch,
7149 stride_out_nhead, reinterpret_cast<const T*>(input.data()),
7250 stride_input_batch, stride_input_nhead,
7351 reinterpret_cast<const T*>(weight.data()), nhead_, dim_,
7452 eps_);
75- >>>>>>> ae94669 (feat: add a convenient interface for any `int64_t`-convertible types and use `DispatchFunc()` to dispatch `DataType` and block sizes with a single call.)
7653 },
7754 " CudaRmsNorm::operator()" );
7855 }
0 commit comments