Skip to content

Commit 7dcfea6

Browse files
committed
fix: fix rebase error
1 parent 3e9d050 commit 7dcfea6

File tree

1 file changed

+1
-24
lines changed

1 file changed

+1
-24
lines changed

src/cuda/rms_norm/kernel.h

Lines changed: 1 addition & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -43,36 +43,13 @@ class CudaRmsNorm : public RmsNorm {
4343
using T = TypeMapType<ListGet<0>(list_tag)>;
4444
constexpr int kBlockSize = ListGet<1>(list_tag);
4545

46-
<<<<<<< HEAD
47-
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
48-
RmsNormKernel<BLOCK_SIZE, float, T, T> \
49-
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
50-
reinterpret_cast<T*>(out.data()), stride_out_batch, \
51-
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
52-
stride_input_batch, stride_input_nhead, \
53-
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
54-
55-
if (block_size == CUDA_BLOCK_SIZE_2048) {
56-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
57-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
58-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024)
59-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
60-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512)
61-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
62-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256)
63-
} else {
64-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128)
65-
}
66-
67-
#undef LAUNCH_RMS_NORM_KERNEL
68-
== == == = RmsNormKernel<kBlockSize, float, T, T>
46+
RmsNormKernel<kBlockSize, float, T, T>
6947
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
7048
reinterpret_cast<T*>(out.data()), stride_out_batch,
7149
stride_out_nhead, reinterpret_cast<const T*>(input.data()),
7250
stride_input_batch, stride_input_nhead,
7351
reinterpret_cast<const T*>(weight.data()), nhead_, dim_,
7452
eps_);
75-
>>>>>>> ae94669 (feat: add a convenient interface for any `int64_t`-convertible types and use `DispatchFunc()` to dispatch `DataType` and block sizes with a single call.)
7653
},
7754
"CudaRmsNorm::operator()");
7855
}

0 commit comments

Comments
 (0)