#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_
#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_

#include <cstddef>
#include <cstdint>
#include <cub/block/block_reduce.cuh>

#include "common/cuda/cast.h"
#include "common/cuda/kernel_commons.h"

namespace infini::ops {

namespace {

// Block-wide sum of squares of `count` elements starting at `data_ptr`.
//
// Each of the block's `block_size` threads accumulates a strided partial sum
// (thread t handles indices t, t + block_size, ...), converting each element
// from TData to the accumulation type TCompute via the project Cast<> helper.
// The partials are then combined with cub::BlockReduce.
//
// NOTE: cub::BlockReduce::Sum returns the valid total only on thread 0; other
// threads receive an unspecified value, so callers must read the result from
// thread 0 (as RmsNormKernel does).
// Requires blockDim.x == block_size.
template <unsigned int block_size, typename TData, typename TCompute>
__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr,
                                               size_t count) {
  TCompute ss = 0;
  for (size_t i = threadIdx.x; i < count; i += block_size) {
    TCompute value = Cast<TCompute>(data_ptr[i]);
    ss += value * value;
  }
  using BlockReduce = cub::BlockReduce<TCompute, block_size>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  return BlockReduce(temp_storage).Sum(ss);
}

}  // namespace
2929
// RMS normalization over the innermost dimension.
//
// Computes, per (batch, head) row of length `dim`:
//   y[i] = x[i] * w[i] * rsqrt(mean(x^2) + epsilon)
// with the reduction and arithmetic carried out in TCompute.
//
// Launch layout: one block per (batch, head) pair, i.e.
//   gridDim.x == batch_count * nhead, blockDim.x == block_size.
// blockIdx.x decodes as batch_idx * nhead + head_idx.
// `w` is shared across batches/heads (only indexed by the dim coordinate).
// Strides are in elements, not bytes.
template <unsigned int block_size, typename TCompute, typename TData,
          typename TWeight>
__global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch,
                              int64_t stride_y_nhead,
                              const TData* __restrict__ x,
                              int64_t stride_x_batch, int64_t stride_x_nhead,
                              const TWeight* __restrict__ w, size_t nhead,
                              size_t dim, float epsilon) {
  size_t batch_idx = blockIdx.x / nhead;
  size_t head_idx = blockIdx.x % nhead;

  auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
  auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead;
  auto w_ptr = w;

  // Block-wide sum of x^2; the valid total lands on thread 0.
  TCompute ss = SumSquared<block_size, TData, TCompute>(x_ptr, dim);

  // Thread 0 owns the reduced sum; broadcast the scale via shared memory.
  __shared__ TCompute rms;
  if (threadIdx.x == 0) {
    rms = Cast<TCompute>(rsqrtf(ss / Cast<TCompute>(dim) + epsilon));
  }
  __syncthreads();

  for (size_t i = threadIdx.x; i < dim; i += block_size) {
    y_ptr[i] =
        Cast<TData>(Cast<TCompute>(x_ptr[i]) * Cast<TCompute>(w_ptr[i]) * rms);
  }
}