|
| 1 | +#ifndef INFINI_OPS_NINETOOTHED_RMS_NORM_H_ |
| 2 | +#define INFINI_OPS_NINETOOTHED_RMS_NORM_H_ |
| 3 | + |
| 4 | +#include <cassert> |
| 5 | +#include <cstdint> |
| 6 | +#include <vector> |
| 7 | + |
| 8 | +#include "base/rms_norm.h" |
| 9 | +#include "data_type.h" |
| 10 | +#include "ninetoothed/tensor.h" |
| 11 | +#include "rms_norm/infini_ops_ninetoothed_rms_norm.h" |
| 12 | + |
| 13 | +namespace infini::ops { |
| 14 | + |
| 15 | +template <> |
| 16 | +class Operator<RmsNorm, Device::Type::kNvidia, 9> : public RmsNorm { |
| 17 | + public: |
| 18 | + using RmsNorm::RmsNorm; |
| 19 | + using RmsNorm::operator(); |
| 20 | + |
| 21 | + void operator()(const Tensor input, const Tensor weight, float eps, |
| 22 | + Tensor out) const override { |
| 23 | + assert(input.dtype() == out.dtype() && out.dtype() == weight.dtype() && |
| 24 | + "operator `RmsNorm` requires all input and output tensors to have " |
| 25 | + "the same dtype"); |
| 26 | + assert(input.shape() == out.shape() && |
| 27 | + "NineToothed `RmsNorm` requires input and output tensors with the " |
| 28 | + "same shape"); |
| 29 | + assert(weight.ndim() == 1 && weight.size(-1) == out.size(-1) && |
| 30 | + "NineToothed `RmsNorm` requires a 1D weight matching the last " |
| 31 | + "dimension"); |
| 32 | + assert( |
| 33 | + (out.ndim() == 2 || out.ndim() == 3 || out.ndim() == 4) && |
| 34 | + "NineToothed `RmsNorm` currently supports rank-2, rank-3, and rank-4 " |
| 35 | + "tensors"); |
| 36 | + |
| 37 | + std::vector<std::uint64_t> weight_sizes; |
| 38 | + std::vector<std::int64_t> weight_strides; |
| 39 | + double eps_value = static_cast<double>(eps); |
| 40 | + std::int64_t num_normalized_elements = |
| 41 | + static_cast<std::int64_t>(out.size(-1)); |
| 42 | + std::uint64_t empty_shape[1] = {}; |
| 43 | + std::int64_t empty_strides[1] = {}; |
| 44 | + |
| 45 | + weight_sizes.assign(out.shape().begin(), out.shape().end()); |
| 46 | + weight_strides.assign(out.ndim(), 0); |
| 47 | + weight_strides.back() = |
| 48 | + weight.strides().empty() ? 1 : weight.strides().back(); |
| 49 | + |
| 50 | + const int dtype_index = ninetoothed::DataTypeIndex(out.dtype()); |
| 51 | + assert( |
| 52 | + dtype_index >= 0 && |
| 53 | + "NineToothed `RmsNorm` supports only float16, bfloat16, and float32"); |
| 54 | + |
| 55 | + ninetoothed::Tensor input_tensor(input); |
| 56 | + ninetoothed::Tensor weight_tensor(const_cast<void*>(weight.data()), |
| 57 | + weight_sizes.data(), |
| 58 | + weight_strides.data()); |
| 59 | + ninetoothed::Tensor eps_tensor(eps_value, empty_shape, empty_strides); |
| 60 | + ninetoothed::Tensor out_tensor(out); |
| 61 | + ninetoothed::Tensor num_normalized_elements_tensor( |
| 62 | + num_normalized_elements, empty_shape, empty_strides); |
| 63 | + |
| 64 | + auto result = launch_infini_ops_ninetoothed_rms_norm( |
| 65 | + static_cast<NineToothedStream>(stream_), input_tensor, weight_tensor, |
| 66 | + eps_tensor, out_tensor, num_normalized_elements_tensor, |
| 67 | + static_cast<int>(out.ndim()), 1, dtype_index, dtype_index, dtype_index); |
| 68 | + |
| 69 | + assert(result == 0 && "NineToothed `RmsNorm` launch failed"); |
| 70 | + } |
| 71 | +}; |
| 72 | + |
| 73 | +} // namespace infini::ops |
| 74 | + |
| 75 | +#endif |
0 commit comments