Skip to content

Commit f44be6f

Browse files
gongchensu
and authored
feat(ops): add MetaX backend for RmsNorm (#25)
- add MetaX `RmsNorm` operator specialization - make the shared CUDA-style rms_norm kernel compatible with MetaX - forward runtime `eps` when launching the kernel Co-authored-by: gongchensu <zhuyue@qiyuanlab.com>
1 parent 1b0b5ac commit f44be6f

File tree

4 files changed

+59
-30
lines changed

4 files changed

+59
-30
lines changed

src/base/rms_norm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class RmsNorm : public Operator<RmsNorm> {
2525
RmsNorm(const Tensor input, const Tensor weight, Tensor out)
2626
: RmsNorm{input, weight, 1e-6f, out} {}
2727

28+
// TODO: Type of `eps` should be `std::optional<float>` instead of `float`.
2829
virtual void operator()(const Tensor input, const Tensor weight, float eps,
2930
Tensor out) const = 0;
3031

src/cuda/rms_norm/kernel.cuh

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
11
#ifndef INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_
22
#define INFINI_OPS_CUDA_RMS_NORM_KERNEL_CUH_
33

4-
#include <cuda_bf16.h>
5-
#include <cuda_fp16.h>
6-
74
#include <cstddef>
85
#include <cstdint>
96
#include <cub/block/block_reduce.cuh>
107

8+
#include "common/cuda/cast.h"
9+
#include "common/cuda/kernel_commons.h"
10+
1111
namespace infini::ops {
1212

1313
namespace {
1414

15-
template <unsigned int block_size, typename Data, typename Compute>
16-
__device__ __forceinline__ Compute SumSquared(const Data* data_ptr,
17-
size_t count) {
18-
Compute ss = 0;
15+
template <unsigned int block_size, typename TData, typename TCompute>
16+
__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr,
17+
size_t count) {
18+
TCompute ss = 0;
1919
for (size_t i = threadIdx.x; i < count; i += block_size) {
20-
Compute val = Compute(data_ptr[i]);
21-
ss += val * val;
20+
TCompute value = Cast<TCompute>(data_ptr[i]);
21+
ss += value * value;
2222
}
23-
using BlockReduce = cub::BlockReduce<Compute, block_size>;
23+
using BlockReduce = cub::BlockReduce<TCompute, block_size>;
2424
__shared__ typename BlockReduce::TempStorage temp_storage;
2525
return BlockReduce(temp_storage).Sum(ss);
2626
}
2727

2828
} // namespace
2929

30-
template <unsigned int block_size, typename Compute, typename Data,
31-
typename Weight>
32-
__global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch,
30+
template <unsigned int block_size, typename TCompute, typename TData,
31+
typename TWeight>
32+
__global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch,
3333
int64_t stride_y_nhead,
34-
const Data* __restrict__ x,
34+
const TData* __restrict__ x,
3535
int64_t stride_x_batch, int64_t stride_x_nhead,
36-
const Weight* __restrict__ w, size_t nhead,
36+
const TWeight* __restrict__ w, size_t nhead,
3737
size_t dim, float epsilon) {
3838
size_t batch_idx = blockIdx.x / nhead;
3939
size_t head_idx = blockIdx.x % nhead;
@@ -42,16 +42,17 @@ __global__ void RmsNormKernel(Data* __restrict__ y, int64_t stride_y_batch,
4242
auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead;
4343
auto w_ptr = w;
4444

45-
Compute ss = SumSquared<block_size, Data, Compute>(x_ptr, dim);
45+
TCompute ss = SumSquared<block_size, TData, TCompute>(x_ptr, dim);
4646

47-
__shared__ Compute rms;
47+
__shared__ TCompute rms;
4848
if (threadIdx.x == 0) {
49-
rms = Compute(rsqrtf(ss / Compute(dim) + epsilon));
49+
rms = Cast<TCompute>(rsqrtf(ss / Cast<TCompute>(dim) + epsilon));
5050
}
5151
__syncthreads();
5252

5353
for (size_t i = threadIdx.x; i < dim; i += block_size) {
54-
y_ptr[i] = Data(Compute(x_ptr[i]) * Compute(w_ptr[i]) * rms);
54+
y_ptr[i] =
55+
Cast<TData>(Cast<TCompute>(x_ptr[i]) * Cast<TCompute>(w_ptr[i]) * rms);
5556
}
5657
}
5758

src/cuda/rms_norm/kernel.h

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33

44
#include <cstdint>
55

6-
// clang-format off
7-
#include <cuda_runtime.h> // TODO: Remove this
8-
// clang-format on
9-
106
#include "base/rms_norm.h"
117
#include "common/cuda/kernel_commons.h"
128
#include "cuda/rms_norm/kernel.cuh"
@@ -45,13 +41,13 @@ class CudaRmsNorm : public RmsNorm {
4541
[&](auto tag) {
4642
using T = typename decltype(tag)::type;
4743

48-
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
49-
RmsNormKernel<BLOCK_SIZE, float, T, T> \
50-
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
51-
reinterpret_cast<T*>(out.data()), stride_out_batch, \
52-
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
53-
stride_input_batch, stride_input_nhead, \
54-
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps_);
44+
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
45+
RmsNormKernel<BLOCK_SIZE, float, T, T> \
46+
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
47+
reinterpret_cast<T*>(out.data()), stride_out_batch, \
48+
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
49+
stride_input_batch, stride_input_nhead, \
50+
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
5551

5652
if (block_size == CUDA_BLOCK_SIZE_2048) {
5753
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)

src/metax/rms_norm/kernel.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef INFINI_OPS_METAX_RMS_NORM_KERNEL_H_
2+
#define INFINI_OPS_METAX_RMS_NORM_KERNEL_H_
3+
4+
#include <utility>
5+
6+
// clang-format off
7+
#include <mcr/mc_runtime.h>
8+
// clang-format on
9+
10+
#include "cuda/rms_norm/kernel.h"
11+
12+
namespace infini::ops {
13+
14+
namespace rms_norm {
15+
16+
struct MetaxBackend {
17+
using stream_t = mcStream_t;
18+
};
19+
20+
} // namespace rms_norm
21+
22+
template <>
23+
class Operator<RmsNorm, Device::Type::kMetax>
24+
: public CudaRmsNorm<rms_norm::MetaxBackend> {
25+
public:
26+
using CudaRmsNorm<rms_norm::MetaxBackend>::CudaRmsNorm;
27+
};
28+
29+
} // namespace infini::ops
30+
31+
#endif

0 commit comments

Comments (0)