|
1 | | -/* |
2 | | - * Copyright (c) 2025, InfiniTensor. |
3 | | - * All rights reserved. |
4 | | - * |
5 | | - * SPDX-License-Identifier: BSD-3-Clause |
6 | | - */ |
7 | | - |
8 | 1 | #include "aclrtlaunch_add_rms_norm.h" |
9 | 2 | #include "tiling/platform/platform_ascendc.h" |
10 | 3 | #include "torch_kernel_helper.h" |
11 | 4 |
|
12 | 5 | namespace ascend::detail { |
13 | 6 |
|
14 | | -std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2, |
15 | | - const at::Tensor& weight, double eps) { |
| 7 | +std::vector<at::Tensor> AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2, |
| 8 | + const at::Tensor& weight, double eps) { |
16 | 9 | // Input validation. |
17 | | - TORCH_CHECK(x1.dim() > 0, "add_rms_norm: x1 must have at least 1 dimension"); |
| 10 | + TORCH_CHECK(x1.dim() > 0, |
| 11 | + "`AddRmsNorm`: `x1` must have at least 1 dimension."); |
18 | 12 | TORCH_CHECK(x1.sizes() == x2.sizes(), |
19 | | - "add_rms_norm: x1 and x2 must have the same shape"); |
| 13 | + "`AddRmsNorm`: `x1` and `x2` must have the same shape."); |
20 | 14 | TORCH_CHECK(x1.scalar_type() == x2.scalar_type(), |
21 | | - "add_rms_norm: x1 and x2 must have the same dtype"); |
| 15 | + "`AddRmsNorm`: `x1` and `x2` must have the same dtype."); |
22 | 16 | TORCH_CHECK(x1.scalar_type() == at::kHalf || x1.scalar_type() == at::kFloat, |
23 | | - "add_rms_norm: only float16 and float32 are supported, got ", |
24 | | - x1.scalar_type()); |
25 | | - TORCH_CHECK(weight.dim() == 1, "add_rms_norm: weight must be 1-dimensional"); |
26 | | - TORCH_CHECK(weight.size(0) == x1.size(-1), "add_rms_norm: weight size (", |
| 17 | + "`AddRmsNorm`: only `float16` and `float32` are supported; got ", |
| 18 | + x1.scalar_type(), "."); |
| 19 | + TORCH_CHECK(weight.dim() == 1, |
| 20 | + "`AddRmsNorm`: `weight` must be 1-dimensional."); |
| 21 | + TORCH_CHECK(weight.size(0) == x1.size(-1), "`AddRmsNorm`: `weight` size (", |
27 | 22 | weight.size(0), ") must match input last dim (", x1.size(-1), |
28 | | - ")"); |
| 23 | + ")."); |
29 | 24 |
|
30 | 25 | int64_t dim_length = x1.size(-1); |
31 | 26 | int64_t total_rows = x1.numel() / dim_length; |
@@ -62,9 +57,10 @@ std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2, |
62 | 57 | int64_t max_dim_length = (ub_size_limit - 1024) / buffer_coefficient; |
63 | 58 | int64_t fp_align_elements = 32 / 4; |
64 | 59 | max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements; |
65 | | - TORCH_CHECK(dim_length_align <= max_dim_length, "add_rms_norm: dim_length ", |
66 | | - dim_length, " (aligned ", dim_length_align, |
67 | | - ") exceeds UB capacity (max ", max_dim_length, ")"); |
| 60 | + TORCH_CHECK(dim_length_align <= max_dim_length, |
| 61 | + "`AddRmsNorm`: `dim_length` ", dim_length, " (aligned ", |
| 62 | + dim_length_align, ") exceeds UB capacity (max ", max_dim_length, |
| 63 | + ")."); |
68 | 64 |
|
69 | 65 | // Padding. |
70 | 66 | at::Tensor kernel_input1; |
@@ -109,6 +105,12 @@ std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2, |
109 | 105 | float eps_float = static_cast<float>(eps); |
110 | 106 | int64_t dtype_size_val = dtype_size; |
111 | 107 |
|
| 108 | + // The first arg `add_rms_norm` is the AscendC kernel entry-point name — it |
| 109 | + // must match `ascendc_add_operator(OP_NAME add_rms_norm)` in `CMakeLists.txt`, |
| 110 | + // the `__global__ __aicore__ void add_rms_norm(...)` definition in |
| 111 | + // `op_kernel/`, and the generated `aclrtlaunch_add_rms_norm.h` header. |
| 112 | + // Google C++ Style's PascalCase rule does NOT apply: this identifier is |
| 113 | + // dictated by the AscendC toolchain's symbol convention. |
112 | 114 | EXEC_KERNEL_CMD(add_rms_norm, block_dim, kernel_input1, kernel_input2, |
113 | 115 | weight_float, kernel_output_y, kernel_output_x_out, |
114 | 116 | total_rows, dim_length, dim_length_align, former_num, |
|
0 commit comments