diff --git a/include/infinicore/nn/layer_norm.hpp b/include/infinicore/nn/layer_norm.hpp new file mode 100644 index 000000000..d69d0594f --- /dev/null +++ b/include/infinicore/nn/layer_norm.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "../ops.hpp" +#include "module.hpp" + +namespace infinicore::nn { + +class LayerNorm : public Module { +public: + /** + * @brief Construct a LayerNorm layer + * + * @param normalized_shape Size of the feature dimension to normalize (typically hidden_size) + * @param eps Small constant for numerical stability (default: 1e-6) + * @param dtype Data type for the weight (default: DataType::F32) + * @param device Device to create the weight on + */ + LayerNorm(size_t normalized_shape, + double eps = 1e-6, + const DataType &dtype = DataType::F32, + const Device &device = Device()); + + /** + * @brief Forward pass: apply LayerNorm + * + * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions + * @return Normalized tensor with same shape as input + * + * The normalization is applied over the last dimension. 
+ * For example: + * Input: [batch, seq_len, hidden_size] -> normalize over hidden_size + * Input: [batch, hidden_size] -> normalize over hidden_size + */ + Tensor forward(const Tensor &x) const; + + // Module information + size_t normalized_shape() const { return normalized_shape_; } + double eps() const { return eps_; } + DataType dtype() const { return dtype_; } + + // String representation + std::string extra_repr() const; + + // Accessors for parameters + Tensor weight() const { return weight_; } + Tensor bias() const { return bias_; } + +protected: + // Parameters + INFINICORE_NN_PARAMETER(weight); + INFINICORE_NN_PARAMETER(bias); + +private: + size_t normalized_shape_; // Size of the feature dimension + double eps_; // Epsilon for numerical stability + DataType dtype_; // Data type for weight +}; + +} // namespace infinicore::nn diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 18741c402..f021be5e9 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -22,6 +22,7 @@ #include "ops/hardswish.hpp" #include "ops/hardtanh.hpp" #include "ops/kv_caching.hpp" +#include "ops/layer_norm.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" diff --git a/include/infinicore/ops/layer_norm.hpp b/include/infinicore/ops/layer_norm.hpp new file mode 100644 index 000000000..da6256b51 --- /dev/null +++ b/include/infinicore/ops/layer_norm.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(LayerNorm, Tensor, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float); + +Tensor layer_norm(const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon = 1e-5f); +void layer_norm_(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon = 1e-5f); +void layer_norm_(Tensor y, const 
Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon = 1e-5f); +void layer_norm_for_pybind(Tensor y, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon = 1e-5f); + +} // namespace infinicore::op diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index f90801a63..11b756f83 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -13,6 +13,7 @@ from .hinge_embedding_loss import hinge_embedding_loss from .huber_loss import huber_loss from .interpolate import interpolate +from .layer_norm import layer_norm from .linear import linear from .linear_w8a8i8 import linear_w8a8i8 from .log_softmax import log_softmax @@ -83,4 +84,5 @@ "softplus", "softsign", "huber_loss", + "layer_norm", ] diff --git a/python/infinicore/nn/functional/layer_norm.py b/python/infinicore/nn/functional/layer_norm.py new file mode 100644 index 000000000..b841fdc07 --- /dev/null +++ b/python/infinicore/nn/functional/layer_norm.py @@ -0,0 +1,33 @@ +from typing import List + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Tensor, + bias: Tensor, + eps: float = 1e-5, + *, + out=None, +) -> Tensor: + r"""Apply Layer Normalization.""" + + assert normalized_shape == weight.shape, ( + "normalized_shape does not match weight.shape." 
+ ) + + if out is None: + return Tensor( + _infinicore.layer_norm( + input._underlying, weight._underlying, bias._underlying, eps + ) + ) + + _infinicore.layer_norm_( + out._underlying, input._underlying, weight._underlying, bias._underlying, eps + ) + + return out diff --git a/src/infinicore/nn/layer_norm.cc b/src/infinicore/nn/layer_norm.cc new file mode 100644 index 000000000..45d3b9cb2 --- /dev/null +++ b/src/infinicore/nn/layer_norm.cc @@ -0,0 +1,27 @@ +#include "infinicore/nn/layer_norm.hpp" +#include "infinicore/ops.hpp" +#include +#include + +namespace infinicore::nn { + +LayerNorm::LayerNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device) + : normalized_shape_(normalized_shape), + eps_(eps), + dtype_(dtype) { + + device_ = device; + + INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device)); + INFINICORE_NN_PARAMETER_INIT(bias, ({normalized_shape}, dtype_, device)); +} + +Tensor LayerNorm::forward(const Tensor &x) const { + return op::layer_norm(x, weight_, bias_, static_cast(eps_)); +} + +std::string LayerNorm::extra_repr() const { + return "LayerNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; +} + +} // namespace infinicore::nn diff --git a/src/infinicore/ops/layer_norm/layer_norm.cc b/src/infinicore/ops/layer_norm/layer_norm.cc new file mode 100644 index 000000000..6749db68e --- /dev/null +++ b/src/infinicore/ops/layer_norm/layer_norm.cc @@ -0,0 +1,42 @@ +#include "infinicore/ops/layer_norm.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(LayerNorm); + +LayerNorm::LayerNorm(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, standardization, std_deviation, x, weight); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, 
standardization, std_deviation, x, weight, bias, epsilon); +} + +void LayerNorm::execute(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(LayerNorm, y, standardization, std_deviation, x, weight, bias, epsilon); +} + +Tensor layer_norm(const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); + auto reduced_shape = x->shape(); + reduced_shape.pop_back(); + auto standardization = Tensor::empty(x->shape(), x->dtype(), x->device()); + auto std_deviation = Tensor::empty(reduced_shape, x->dtype(), x->device()); + layer_norm_(y, standardization, std_deviation, x, weight, bias, epsilon); + return y; +} + +void layer_norm_(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + LayerNorm::execute(y, standardization, std_deviation, x, weight, bias, epsilon); +} + +void layer_norm_(Tensor y, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + auto reduced_shape = x->shape(); + reduced_shape.pop_back(); + auto standardization = Tensor::empty(x->shape(), x->dtype(), x->device()); + auto std_deviation = Tensor::empty(reduced_shape, x->dtype(), x->device()); + LayerNorm::execute(y, standardization, std_deviation, x, weight, bias, epsilon); +} + +void layer_norm_for_pybind(Tensor y, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + layer_norm_(y, x, weight, bias, epsilon); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/layer_norm/layer_norm_infiniop.cc b/src/infinicore/ops/layer_norm/layer_norm_infiniop.cc new file mode 100644 index 000000000..2b4a5f335 --- /dev/null +++ b/src/infinicore/ops/layer_norm/layer_norm_infiniop.cc @@ -0,0 +1,65 @@ +#include "infinicore/ops/layer_norm.hpp" + +#include "../infiniop_impl.hpp" + +namespace 
infinicore::op::layer_norm_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, LayerNorm, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, y, standardization, std_deviation, x, weight, bias; +}; + +void *plan(Tensor y, Tensor standardization, Tensor std_deviation, const Tensor &x, const Tensor &weight, const Tensor &bias, float epsilon) { + size_t seed = hash_combine(y, standardization, std_deviation, x, weight, bias, epsilon); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, LayerNorm, + seed, + y->desc(), + standardization->desc(), + std_deviation->desc(), + x->desc(), + weight->desc(), + bias->desc(), + epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, LayerNorm, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(standardization), + graph::GraphTensor(std_deviation), + graph::GraphTensor(x), + graph::GraphTensor(weight), + graph::GraphTensor(bias)}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopLayerNorm( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->y->data(), + planned->standardization->data(), + planned->std_deviation->data(), + planned->x->data(), + planned->weight->data(), + planned->bias->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(LayerNorm, &plan, &run, &cleanup); + +} // namespace infinicore::op::layer_norm_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 9e3ac4377..c9c780aad 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -54,6 +54,7 @@ #include "ops/kron.hpp" #include "ops/kthvalue.hpp" #include "ops/kv_caching.hpp" 
+#include "ops/layer_norm.hpp" #include "ops/ldexp.hpp" #include "ops/lerp.hpp" #include "ops/linear.hpp" @@ -216,6 +217,7 @@ inline void bind(py::module &m) { bind_triplet_margin_loss(m); bind_selu(m); bind_sinh(m); + bind_layer_norm(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/layer_norm.hpp b/src/infinicore/pybind11/ops/layer_norm.hpp new file mode 100644 index 000000000..5ca6a87d3 --- /dev/null +++ b/src/infinicore/pybind11/ops/layer_norm.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include + +#include "infinicore/ops/layer_norm.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_layer_norm(py::module &m) { + m.def("layer_norm", + &op::layer_norm, + py::arg("x"), + py::arg("weight"), + py::arg("bias"), + py::arg("epsilon") = 1e-5f, + R"doc(Layer Normalization. + +Args: + x: Input tensor + weight: Scale weights + bias: Bias weights + epsilon: Small constant for numerical stability, default is 1e-5 + +Returns: + Normalized tensor with same shape as input +)doc"); + + m.def("layer_norm_", + &op::layer_norm_for_pybind, + py::arg("y"), + py::arg("x"), + py::arg("weight"), + py::arg("bias"), + py::arg("epsilon") = 1e-5f, + R"doc(In-place Layer Normalization. 
+ +Args: + y: Output tensor + x: Input tensor + weight: Scale weights + bias: Bias weights + epsilon: Small constant for numerical stability, default is 1e-5 +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/devices/metax/metax_kernel_common.h b/src/infiniop/devices/metax/metax_kernel_common.h index 3d2b31a5e..7066a6ebd 100644 --- a/src/infiniop/devices/metax/metax_kernel_common.h +++ b/src/infiniop/devices/metax/metax_kernel_common.h @@ -5,7 +5,9 @@ #include #include #include +#include #else +#include #include #include #include diff --git a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc index 88f386977..7331514e5 100644 --- a/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc +++ b/src/infiniop/ops/layer_norm/cpu/layer_norm_cpu.cc @@ -15,31 +15,127 @@ infiniStatus_t calculate_layer_norm( const Tdata *weight, const Tdata *bias) { + const size_t ndim = info.ndim; + const size_t norm_size = info.normalized_size; // last dim + const size_t othersize = info.othersize; + + const auto &shape = info.input_shape; + const auto &in_strides = info.input_strides; + const auto &out_strides = info.output_strides; + const auto &std_strides = info.input_standardization_strides; + const auto &stddev_strides = info.input_std_deviation_strides; + + // -------- Special fast path: 1D tensor -------- + if (ndim == 1) { + const Tdata *input_ptr = input; + Tdata *output_ptr = output; + Tdata *standard_ptr = input_standardization; + Tdata *std_ptr = input_std_deviation; + + float mean = op::common_cpu::reduce_op::sum( + input_ptr, + norm_size, + in_strides[0]) + / norm_size; + + float sum_sq = op::common_cpu::reduce_op::sumSquared( + input_ptr, + norm_size, + in_strides[0]); + + float var = sum_sq / norm_size - mean * mean; + float std_dev = std::sqrt(var + info.eps); + + *std_ptr = utils::cast(std_dev); + + for (size_t d = 0; d < norm_size; d++) { + float x = utils::cast( + *(input_ptr + d * in_strides[0])); + + 
float x_std = (x - mean) / std_dev; + + *(standard_ptr + d * std_strides[0]) = utils::cast(x_std); + + float w = utils::cast( + *(weight + d * info.weight_strides[0])); + + float bval = info.bias_exist + ? utils::cast( + *(bias + d * info.bias_strides[0])) + : 0.0f; + + *(output_ptr + d * out_strides[0]) = utils::cast(x_std * w + bval); + } + + return INFINI_STATUS_SUCCESS; + } + + // -------- General case: ndim >= 2 -------- + // Each iteration derives its own multi-dim index from the flat row index b, + // so no mutable state is shared between OpenMP threads. + #pragma omp parallel for - for (int b = 0; b < (int)(info.input_shape[0] * info.input_shape[1]); b++) { - int b0 = b / (int)info.input_shape[1], b1 = b % (int)info.input_shape[1]; - auto output_ptr = output + b0 * info.output_strides[0] + b1 * info.output_strides[1]; - auto input_ptr = input + b0 * info.input_strides[0] + b1 * info.input_strides[1]; - auto standard_ptr = input_standardization + b0 * info.input_standardization_strides[0] + b1 * info.input_standardization_strides[1]; - auto std_ptr = input_std_deviation + b0 * info.input_std_deviation_strides[0] + b1 * info.input_std_deviation_strides[1]; + for (ptrdiff_t b = 0; b < (ptrdiff_t)othersize; b++) { + + // ---- compute base offsets ---- + // Peel per-dimension coordinates off b, last leading dim first (row-major order). + ptrdiff_t in_offset = 0; + ptrdiff_t out_offset = 0; + ptrdiff_t std_offset = 0; + ptrdiff_t stddev_offset = 0; + ptrdiff_t rem = b; + + for (int d = (int)ndim - 2; d >= 0; d--) { + const ptrdiff_t extent = (ptrdiff_t)shape[d]; + const ptrdiff_t coord = rem % extent; + rem /= extent; + in_offset += coord * in_strides[d]; + out_offset += coord * out_strides[d]; + std_offset += coord * std_strides[d]; + stddev_offset += coord * stddev_strides[d]; + } + + const Tdata *input_ptr = input + in_offset; + Tdata *output_ptr = output + out_offset; + Tdata *standard_ptr = input_standardization + std_offset; + Tdata *std_ptr = input_std_deviation + stddev_offset; + + // ---- mean ---- float mean = op::common_cpu::reduce_op::sum( input_ptr, - info.normalized_size, - info.input_strides[2]) - / info.input_shape[2]; + norm_size, + in_strides[ndim - 1]) + / norm_size; + + // ---- variance ---- float sum_sq = 
op::common_cpu::reduce_op::sumSquared( input_ptr, - info.normalized_size, - info.input_strides[2]); - float var = sum_sq / (info.normalized_size) - mean * mean; - float std_deviation = std::sqrt(var + info.eps); - *std_ptr = utils::cast(std_deviation); - - for (size_t d = 0; d < info.normalized_size; d++) { - float x_standard = (utils::cast(*(input_ptr + d * info.input_strides[2])) - mean) / std_deviation; - *(standard_ptr + d * info.input_standardization_strides[2]) = utils::cast(x_standard); - *(output_ptr + d * info.output_strides[2]) = utils::cast( - x_standard * utils::cast(*(weight + d * info.weight_strides[0])) + (info.bias_exist ? utils::cast(*(bias + d * info.bias_strides[0])) : float(0))); + norm_size, + in_strides[ndim - 1]); + + float var = sum_sq / norm_size - mean * mean; + float std_dev = std::sqrt(var + info.eps); + + *std_ptr = utils::cast(std_dev); + + // ---- normalize ---- + for (size_t d = 0; d < norm_size; d++) { + float x = utils::cast( + *(input_ptr + d * in_strides[ndim - 1])); + + float x_std = (x - mean) / std_dev; + + *(standard_ptr + d * std_strides[ndim - 1]) = utils::cast(x_std); + + float w = utils::cast( + *(weight + d * info.weight_strides[0])); + + float bval = info.bias_exist + ? 
utils::cast( + *(bias + d * info.bias_strides[0])) + : 0.0f; + + *(output_ptr + d * out_strides[ndim - 1]) = utils::cast(x_std * w + bval); } } @@ -84,9 +180,14 @@ infiniStatus_t Descriptor::create( return INFINI_STATUS_SUCCESS; } -#define CALCULATE_LAYER_NORM(TDATA) \ - CHECK_STATUS(calculate_layer_norm(_info, \ - (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias)) +#define CALCULATE_LAYER_NORM(TDATA) \ + CHECK_STATUS(calculate_layer_norm(_info, \ + (TDATA *)output, \ + (TDATA *)input_standardization, \ + (TDATA *)input_std_deviation, \ + (const TDATA *)input, \ + (const TDATA *)weight, \ + (const TDATA *)bias)) infiniStatus_t Descriptor::calculate( void *workspace, diff --git a/src/infiniop/ops/layer_norm/metax/layer_norm_metax.maca b/src/infiniop/ops/layer_norm/metax/layer_norm_metax.maca index cfa90368a..b2de1da70 100644 --- a/src/infiniop/ops/layer_norm/metax/layer_norm_metax.maca +++ b/src/infiniop/ops/layer_norm/metax/layer_norm_metax.maca @@ -1,35 +1,29 @@ #include "../../../devices/metax/metax_common.h" -#include "layer_norm_metax.h" -#ifdef ENABLE_METAX_MC_API -#include -#else -#include -#endif #include "../../../devices/metax/metax_kernel_common.h" #include "../../../reduce/cuda/reduce.cuh" #include "../cuda/kernel.cuh" #include "../info.h" +#include "layer_norm_metax.h" namespace op::layer_norm::metax { template INFINIOP_METAX_KERNEL launchKernel( - Tdata * output, - Tdata * input_standardization, - Tdata * input_std_deviation, - const Tdata * input, - const Tdata * weight, - const Tdata * bias, + Tdata *output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, float eps, size_t normalized_size, - const 
ptrdiff_t* output_strides, - const ptrdiff_t* input_standardization_strides, - const ptrdiff_t* input_std_deviation_strides, - const ptrdiff_t* input_strides, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_standardization_strides, + const ptrdiff_t *input_std_deviation_strides, + const ptrdiff_t *input_strides, ptrdiff_t weight_stride, ptrdiff_t bias_stride, - bool bias_exist -) { + bool bias_exist) { layerNormKernel( output, input_standardization, @@ -45,56 +39,142 @@ INFINIOP_METAX_KERNEL launchKernel( input_strides, weight_stride, bias_stride, - bias_exist - ); + bias_exist); +} + +template +INFINIOP_METAX_KERNEL blockLayernorm( + Tdata *output, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, + float eps, + int dimsize, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + int ndim, + bool bias_exist) { + blockLayernormKernel(output, + input, + weight, + bias, + eps, + dimsize, + output_strides, + input_strides, + shape, + weight_stride, + bias_stride, + ndim, + bias_exist); +} + +template +INFINIOP_METAX_KERNEL warpLayernorm( + Tdata *output, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, + float eps, + int othersize, + int dimsize, + const ptrdiff_t *output_strides, + const ptrdiff_t *input_strides, + const size_t *shape, + ptrdiff_t weight_stride, + ptrdiff_t bias_stride, + int ndim, + bool bias_exist) { + warpLayernormKernel(output, + input, + weight, + bias, + eps, + othersize, + dimsize, + output_strides, + input_strides, + shape, + weight_stride, + bias_stride, + ndim, + bias_exist); } // ----------------------------------- start: call launchKernel ----------------------------------- template infiniStatus_t calculate_layer_norm( const LayerNormInfo &info, - Tdata * output, - Tdata * input_standardization, - Tdata * input_std_deviation, - const Tdata * input, - const Tdata * weight, - const Tdata * bias, + Tdata 
*output, + Tdata *input_standardization, + Tdata *input_std_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, hcStream_t stream, - void *workspace -) { + void *workspace) { size_t ndim = info.ndim; - ptrdiff_t * input_strides_cuda = reinterpret_cast(workspace); - ptrdiff_t * output_strides_cuda = input_strides_cuda + ndim; - ptrdiff_t * input_standardization_strides_cuda = output_strides_cuda + ndim; - ptrdiff_t * input_std_deviation_strides_cuda = input_standardization_strides_cuda + ndim; + char *workspace_ptr = reinterpret_cast(workspace); + ptrdiff_t *input_strides_cuda = reinterpret_cast(workspace); + ptrdiff_t *output_strides_cuda = input_strides_cuda + ndim; + ptrdiff_t *input_standardization_strides_cuda = output_strides_cuda + ndim; + ptrdiff_t *input_std_deviation_strides_cuda = input_standardization_strides_cuda + ndim; + size_t ptrdiff_array_size = 4 * ndim * sizeof(ptrdiff_t); + size_t *shape_cuda = reinterpret_cast(workspace_ptr + ptrdiff_array_size); CHECK_METAX(hcMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync(input_standardization_strides_cuda, info.input_standardization_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), hcMemcpyHostToDevice, stream)); CHECK_METAX(hcMemcpyAsync(input_std_deviation_strides_cuda, info.input_std_deviation_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), hcMemcpyHostToDevice, stream)); + CHECK_METAX(hcMemcpyAsync(shape_cuda, info.input_shape.data(), sizeof(size_t) * ndim, hcMemcpyHostToDevice, stream)); + int dimsize = (int)info.normalized_size; + int num_blocks = (int)info.othersize; - launchKernel<1, Tdata, float><<>>( - output, - input_standardization, - input_std_deviation, - input, - weight, - bias, - info.eps, - info.normalized_size, - output_strides_cuda, - 
input_standardization_strides_cuda, - input_std_deviation_strides_cuda, - input_strides_cuda, - info.weight_strides[0], - info.bias_exist ? info.bias_strides[0] : 0, - info.bias_exist - ); + if (dimsize > 1024) { + blockLayernorm + <<>>(output, + input, + weight, + bias, + info.eps, + dimsize, + output_strides_cuda, + input_strides_cuda, + shape_cuda, + info.weight_strides[0], + info.bias_exist ? info.bias_strides[0] : 0, + (int)info.ndim, + info.bias_exist); + } else { + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + + int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + warpLayernorm + <<>>(output, + input, + weight, + bias, + info.eps, + num_blocks, + dimsize, + output_strides_cuda, + input_strides_cuda, + shape_cuda, + info.weight_strides[0], + info.bias_exist ? info.bias_strides[0] : 0, + (int)info.ndim, + info.bias_exist); + } return INFINI_STATUS_SUCCESS; } // ------------------------------------ end: call launchKernel ------------------------------------ - struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -112,10 +192,9 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t input_desc, infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t bias_desc, - float eps -) { + float eps) { auto handle = reinterpret_cast(handle_); -// --------------------- start: check data type and calculate workspace size ---------------------- + // --------------------- start: check data type and calculate workspace size ---------------------- auto dtype = output_desc->dtype(); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); auto result = LayerNormInfo::createLayerNormInfo( @@ -125,62 +204,66 @@ infiniStatus_t Descriptor::create( input_desc, weight_desc, bias_desc, - eps - ); + eps); CHECK_RESULT(result); const LayerNormInfo &info = result.take(); - size_t WorkSpaceSize = 
sizeof(ptrdiff_t) * input_desc->ndim() * 4; -// ---------------------- end: check data type and calculate workspace size ----------------------- + size_t WorkSpaceSize = output_desc->ndim() * (sizeof(ptrdiff_t) * 4 + sizeof(size_t)); + // ---------------------- end: check data type and calculate workspace size ----------------------- *desc_ptr = new Descriptor( - dtype, std::move(info), WorkSpaceSize, + dtype, std::move(info), WorkSpaceSize, new Opaque{handle->internal()}, - handle->device, handle->device_id - ); + handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } - - infiniStatus_t Descriptor::calculate( - void * workspace, + void *workspace, size_t workspace_size, - void * output, - void * input_standardization, - void * input_std_deviation, - const void * input, - const void * weight, - const void * bias, - void *stream_ -) const { - if (workspace_size < _workspace_size) + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream_) const { + if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } hcStream_t stream = (hcStream_t)stream_; - #define CALCULATE_LAYER_NORM(BLOCK_SIZE, TDATA) \ - calculate_layer_norm(_info, (TDATA *)output, (TDATA *)input_standardization, (TDATA *)input_std_deviation, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias, stream, workspace) - #define CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK(BLOCK_SIZE) \ - { \ - if (_info.dtype == INFINI_DTYPE_F16) \ - return CALCULATE_LAYER_NORM(BLOCK_SIZE, half); \ - else if (_info.dtype == INFINI_DTYPE_F32) \ - return CALCULATE_LAYER_NORM(BLOCK_SIZE, float); \ - else if (_info.dtype == INFINI_DTYPE_BF16) \ - return CALCULATE_LAYER_NORM(BLOCK_SIZE, cuda_bfloat16); \ - else \ - return INFINI_STATUS_BAD_TENSOR_DTYPE; \ +#define CALCULATE_LAYER_NORM(BLOCK_SIZE, TDATA) \ + calculate_layer_norm(_info, \ + (TDATA *)output, \ + (TDATA 
*)input_standardization, \ + (TDATA *)input_std_deviation, \ + (const TDATA *)input, \ + (const TDATA *)weight, \ + (const TDATA *)bias, stream, workspace) + +#define CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_LAYER_NORM(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_LAYER_NORM(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_LAYER_NORM(BLOCK_SIZE, cuda_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } - if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) + if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK(METAX_BLOCK_SIZE_1024) - else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) + } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK(METAX_BLOCK_SIZE_512) - else + } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } return INFINI_STATUS_SUCCESS; - #undef CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK - #undef CALCULATE_LAYER_NORM +#undef CALCULATE_LAYER_NORM_WITH_MATEX_BLOCK +#undef CALCULATE_LAYER_NORM } } // namespace op::layer_norm::metax diff --git a/src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu b/src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu index 5546f5e2f..fac44867b 100644 --- a/src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu +++ b/src/infiniop/ops/layer_norm/nvidia/layer_norm_nvidia.cu @@ -128,6 +128,7 @@ infiniStatus_t calculate_layer_norm( size_t ptrdiff_array_size = 4 * ndim * sizeof(ptrdiff_t); size_t *shape_cuda = reinterpret_cast(workspace_ptr + ptrdiff_array_size); + /// @todo: h2d copy breaks cuda graph, need to optimize this part in the future CHECK_CUDA(cudaMemcpyAsync(input_strides_cuda, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); 
CHECK_CUDA(cudaMemcpyAsync(output_strides_cuda, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, cudaMemcpyHostToDevice, stream)); CHECK_CUDA(cudaMemcpyAsync(input_standardization_strides_cuda, info.input_standardization_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), cudaMemcpyHostToDevice, stream)); @@ -244,7 +245,7 @@ infiniStatus_t Descriptor::calculate( else if (_info.dtype == INFINI_DTYPE_F32) \ return CALCULATE_LAYER_NORM(BLOCK_SIZE, float); \ else if (_info.dtype == INFINI_DTYPE_BF16) \ - return CALCULATE_LAYER_NORM(BLOCK_SIZE, __nv_bfloat16); \ + return CALCULATE_LAYER_NORM(BLOCK_SIZE, cuda_bfloat16); \ else \ return INFINI_STATUS_BAD_TENSOR_DTYPE; \ } diff --git a/test/infinicore/ops/layer_norm.py b/test/infinicore/ops/layer_norm.py new file mode 100644 index 000000000..b2b158dff --- /dev/null +++ b/test/infinicore/ops/layer_norm.py @@ -0,0 +1,141 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import infinicore +import torch +from framework import ( + BaseOperatorTest, + TensorSpec, + TestCase, + GenericTestRunner, + is_broadcast, +) + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (y_shape, x_shape, w_b_shape, y_strides, x_strides) +_TEST_CASES_DATA = [ + # Basic cases + ((1, 4), (1, 4), (4,), None, None), + ((2, 4), (2, 4), (4,), None, None), + ((2, 2, 4), (2, 2, 4), (4,), None, None), + # Strided cases + ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)), + # Large tensors + ((16, 2048), (16, 2048), (2048,), None, None), + ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 2e-3, "rtol": 2e-3}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-4}, +} + +# Data types for individual 
tensors +_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16] + +# EPSILON constant for LayerNorm +_EPSILON = 1e-5 + + +def parse_test_cases(): + """ + Parse LayerNorm test case data and return list of TestCase objects. + Format: (y_shape, x_shape, w_shape, y_strides, x_strides) + """ + test_cases = [] + + for data in _TEST_CASES_DATA: + y_shape = data[0] # Output shape + x_shape = data[1] # Input shape + w_b_shape = data[2] # Weight shape (1D) + y_strides = data[3] if len(data) > 3 else None + x_strides = data[4] if len(data) > 4 else None + + y_supports_inplace = not is_broadcast(y_strides) + + # Generate test cases for all dtype combinations + for input_dtype in _INPUT_DTYPES: + weight_dtype = input_dtype + # Use input dtype tolerance for output + tolerance = _TOLERANCE_MAP.get(input_dtype, {"atol": 1e-5, "rtol": 1e-4}) + + # Create typed tensor specs + x_spec = TensorSpec.from_tensor(x_shape, x_strides, input_dtype) + w_spec = TensorSpec.from_tensor( + w_b_shape, None, weight_dtype + ) # Weight is always contiguous + b_spec = TensorSpec.from_tensor( + w_b_shape, None, weight_dtype + ) # Bias is always contiguous + y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) + + # Test Case 1: Out-of-place (return value) + test_cases.append( + TestCase( + inputs=[x_spec, w_spec, b_spec], + kwargs={"epsilon": _EPSILON}, + output_spec=None, + comparison_target=None, + tolerance=tolerance, + description="LayerNorm - OUT_OF_PLACE", + ) + ) + + # Test Case 2: In-place with explicit output tensor (layer_norm(x, w, out=y)) + if y_supports_inplace: + test_cases.append( + TestCase( + inputs=[x_spec, w_spec, b_spec], + kwargs={"epsilon": _EPSILON}, + output_spec=y_spec, # Specify the output tensor spec + comparison_target="out", + tolerance=tolerance, + description="LayerNorm - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """LayerNorm operator test with simplified implementation""" + + def __init__(self): + 
super().__init__("LayerNorm") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, x, weight, bias, epsilon=_EPSILON, out=None, **kwargs): + """PyTorch LayerNorm implementation""" + result = torch.nn.functional.layer_norm(x, weight.shape, weight, bias, epsilon) + + if out is not None: + out.copy_(result) + return out + return result + + def infinicore_operator( + self, x, weight, bias, epsilon=_EPSILON, out=None, **kwargs + ): + """InfiniCore LayerNorm implementation""" + import infinicore.nn.functional as F + + return F.layer_norm(x, weight.shape, weight, bias, epsilon, out=out) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main()