Commit 720234d

Authored and committed by zhangyue

refactor(ascend): address PR #64 review — clean headers, Markdown in TORCH_CHECK, Google C++ naming
- `workspace_pool_.h`: uncomment `<cinttypes>` / `<cstdio>` (needed for `PRIu64` and `fprintf` in the destructor; not transitively available on all platforms); a minimal sketch follows this list.
- `device_.h`: switch to a plain `#include "device.h"` (previously routed through `data_type.h` to dodge the historical `src/ascend/device.h` naming collision, which is no longer relevant).
- `custom/{add_rms_norm,rms_norm}/op_host/*.cpp`: drop unneeded BSD-3-Clause headers and switch `TORCH_CHECK` messages to Markdown-backticked identifiers.
- `custom/{add_rms_norm,rms_norm}/op_kernel/*.cpp`: drop unneeded BSD-3-Clause headers.
- Rename wrapper functions to PascalCase per Google C++ Style: `add_rms_norm` → `AddRmsNorm`, `rms_norm` → `RmsNorm` (`ops.h` and `torch_binding.cpp` updated; the `torch.ops.npu.rms_norm` registry name is unchanged; kernel entry-point names stay snake_case as required by `EXEC_KERNEL_CMD`).
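
For the `workspace_pool_.h` item, here is a minimal standalone sketch of why both headers are needed; the struct and field names are illustrative only, not the actual pool code. `PRIu64` is a macro from `<cinttypes>` and `fprintf`/`stderr` come from `<cstdio>`, so a destructor that logs a `uint64_t` counter fails to compile wherever other headers do not pull these in transitively.

```cpp
#include <cinttypes>  // PRIu64 format macro
#include <cstdint>    // uint64_t
#include <cstdio>     // std::fprintf, stderr

// Hypothetical stand-in for the pool's bookkeeping; not the real workspace_pool_ code.
struct PoolStatsLogger {
  uint64_t bytes_in_use = 0;

  ~PoolStatsLogger() {
    // PRIu64 expands to the right printf length specifier for uint64_t.
    std::fprintf(stderr, "workspace still in use: %" PRIu64 " bytes\n",
                 bytes_in_use);
  }
};
```
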
1 parent e419d24 commit 720234d

7 files changed: 44 additions & 59 deletions

src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp

Lines changed: 22 additions & 20 deletions
@@ -1,31 +1,26 @@
-/*
- * Copyright (c) 2025, InfiniTensor.
- * All rights reserved.
- *
- * SPDX-License-Identifier: BSD-3-Clause
- */
-
 #include "aclrtlaunch_add_rms_norm.h"
 #include "tiling/platform/platform_ascendc.h"
 #include "torch_kernel_helper.h"
 
 namespace ascend::detail {
 
-std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2,
-                                     const at::Tensor& weight, double eps) {
+std::vector<at::Tensor> AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2,
+                                   const at::Tensor& weight, double eps) {
   // Input validation.
-  TORCH_CHECK(x1.dim() > 0, "add_rms_norm: x1 must have at least 1 dimension");
+  TORCH_CHECK(x1.dim() > 0,
+              "`AddRmsNorm`: `x1` must have at least 1 dimension.");
   TORCH_CHECK(x1.sizes() == x2.sizes(),
-              "add_rms_norm: x1 and x2 must have the same shape");
+              "`AddRmsNorm`: `x1` and `x2` must have the same shape.");
   TORCH_CHECK(x1.scalar_type() == x2.scalar_type(),
-              "add_rms_norm: x1 and x2 must have the same dtype");
+              "`AddRmsNorm`: `x1` and `x2` must have the same dtype.");
   TORCH_CHECK(x1.scalar_type() == at::kHalf || x1.scalar_type() == at::kFloat,
-              "add_rms_norm: only float16 and float32 are supported, got ",
-              x1.scalar_type());
-  TORCH_CHECK(weight.dim() == 1, "add_rms_norm: weight must be 1-dimensional");
-  TORCH_CHECK(weight.size(0) == x1.size(-1), "add_rms_norm: weight size (",
+              "`AddRmsNorm`: only `float16` and `float32` are supported; got ",
+              x1.scalar_type(), ".");
+  TORCH_CHECK(weight.dim() == 1,
+              "`AddRmsNorm`: `weight` must be 1-dimensional.");
+  TORCH_CHECK(weight.size(0) == x1.size(-1), "`AddRmsNorm`: `weight` size (",
               weight.size(0), ") must match input last dim (", x1.size(-1),
-              ")");
+              ").");
 
   int64_t dim_length = x1.size(-1);
   int64_t total_rows = x1.numel() / dim_length;
@@ -62,9 +57,10 @@ std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2,
   int64_t max_dim_length = (ub_size_limit - 1024) / buffer_coefficient;
   int64_t fp_align_elements = 32 / 4;
   max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements;
-  TORCH_CHECK(dim_length_align <= max_dim_length, "add_rms_norm: dim_length ",
-              dim_length, " (aligned ", dim_length_align,
-              ") exceeds UB capacity (max ", max_dim_length, ")");
+  TORCH_CHECK(dim_length_align <= max_dim_length,
+              "`AddRmsNorm`: `dim_length` ", dim_length, " (aligned ",
+              dim_length_align, ") exceeds UB capacity (max ", max_dim_length,
+              ").");
 
   // Padding.
   at::Tensor kernel_input1;
@@ -109,6 +105,12 @@ std::vector<at::Tensor> add_rms_norm(const at::Tensor& x1, const at::Tensor& x2,
   float eps_float = static_cast<float>(eps);
   int64_t dtype_size_val = dtype_size;
 
+  // The first arg `add_rms_norm` is the AscendC kernel entry-point name — it
+  // must match `ascendc_add_operator(OP_NAME add_rms_norm)` in `CMakeLists.txt`,
+  // the `__global__ __aicore__ void add_rms_norm(...)` definition in
+  // `op_kernel/`, and the generated `aclrtlaunch_add_rms_norm.h` header.
+  // Google C++ Style's PascalCase rule does NOT apply: this identifier is
+  // dictated by the AscendC toolchain's symbol convention.
   EXEC_KERNEL_CMD(add_rms_norm, block_dim, kernel_input1, kernel_input2,
                   weight_float, kernel_output_y, kernel_output_x_out,
                   total_rows, dim_length, dim_length_align, former_num,
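
The comment added before `EXEC_KERNEL_CMD` describes a three-way naming contract. The sketch below restates it schematically; the argument lists are elided, the `__global__ __aicore__` qualifiers only compile under the AscendC toolchain, and none of these lines are code from this commit.

```cpp
// CMakeLists.txt (per the comment): ascendc_add_operator(OP_NAME add_rms_norm)

// op_kernel/add_rms_norm.cpp: the device entry point keeps the snake_case
// name that the toolchain and the generated aclrtlaunch_add_rms_norm.h expect.
__global__ __aicore__ void add_rms_norm(/* GM tensors and tiling scalars */);

// op_host/add_rms_norm.cpp: the host wrapper is ordinary C++, so the Google
// PascalCase rule applies here and only here.
namespace ascend::detail {
std::vector<at::Tensor> AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2,
                                   const at::Tensor& weight, double eps);
}  // namespace ascend::detail
```
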

src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp

Lines changed: 0 additions & 7 deletions
@@ -1,10 +1,3 @@
-/*
- * Copyright (c) 2025, InfiniTensor.
- * All rights reserved.
- *
- * SPDX-License-Identifier: BSD-3-Clause
- */
-
 #include "kernel_operator.h"
 
 constexpr int32_t BUFFER_NUM = 2;

src/ascend/custom/ops.h

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@
 
 namespace ascend::detail {
 
-at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight,
-                    double eps);
+at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
+                   double eps);
 
 } // namespace ascend::detail
 

src/ascend/custom/rms_norm/op_host/rms_norm.cpp

Lines changed: 18 additions & 18 deletions
@@ -1,29 +1,22 @@
-/*
- * Copyright (c) 2025, InfiniTensor.
- * All rights reserved.
- *
- * SPDX-License-Identifier: BSD-3-Clause
- */
-
 #include "aclrtlaunch_rms_norm.h"
 #include "tiling/platform/platform_ascendc.h"
 #include "torch_kernel_helper.h"
 
 namespace ascend::detail {
 
-at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight,
-                    double eps) {
+at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
+                   double eps) {
   // Input validation.
   TORCH_CHECK(input.dim() > 0,
-              "rms_norm: input must have at least 1 dimension");
+              "`RmsNorm`: `input` must have at least 1 dimension.");
   TORCH_CHECK(
       input.scalar_type() == at::kHalf || input.scalar_type() == at::kFloat,
-      "rms_norm: only float16 and float32 are supported, got ",
-      input.scalar_type());
-  TORCH_CHECK(weight.dim() == 1, "rms_norm: weight must be 1-dimensional");
-  TORCH_CHECK(weight.size(0) == input.size(-1), "rms_norm: weight size (",
+      "`RmsNorm`: only `float16` and `float32` are supported; got ",
+      input.scalar_type(), ".");
+  TORCH_CHECK(weight.dim() == 1, "`RmsNorm`: `weight` must be 1-dimensional.");
+  TORCH_CHECK(weight.size(0) == input.size(-1), "`RmsNorm`: `weight` size (",
               weight.size(0), ") must match input last dim (", input.size(-1),
-              ")");
+              ").");
 
   int64_t dim_length = input.size(-1);
   int64_t total_rows = input.numel() / dim_length;
@@ -61,9 +54,10 @@ at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight,
   // `fp32` alignment.
   int64_t fp_align_elements = 32 / 4;
   max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements;
-  TORCH_CHECK(dim_length_align <= max_dim_length, "rms_norm: dim_length ",
-              dim_length, " (aligned ", dim_length_align,
-              ") exceeds UB capacity (max ", max_dim_length, ")");
+  TORCH_CHECK(dim_length_align <= max_dim_length,
+              "`RmsNorm`: `dim_length` ", dim_length, " (aligned ",
+              dim_length_align, ") exceeds UB capacity (max ", max_dim_length,
+              ").");
 
   // Padding.
   at::Tensor kernel_input;
@@ -100,6 +94,12 @@ at::Tensor rms_norm(const at::Tensor& input, const at::Tensor& weight,
   float eps_float = static_cast<float>(eps);
   int64_t dtype_size_val = dtype_size;
 
+  // The first arg `rms_norm` is the AscendC kernel entry-point name — it
+  // must match `ascendc_add_operator(OP_NAME rms_norm)` in `CMakeLists.txt`,
+  // the `__global__ __aicore__ void rms_norm(...)` definition in `op_kernel/`,
+  // and the generated `aclrtlaunch_rms_norm.h` header. Google C++ Style's
+  // PascalCase rule does NOT apply: this identifier is dictated by the
+  // AscendC toolchain's symbol convention.
   EXEC_KERNEL_CMD(rms_norm, block_dim, kernel_input, weight_float,
                   kernel_output, total_rows, dim_length, dim_length_align,
                   former_num, former_length, tail_length, eps_float,
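
Both reworded "exceeds UB capacity" checks gate on the same align-down arithmetic shown in the hunk above. The following self-contained example walks through it with assumed numbers; `ub_size_limit` and `buffer_coefficient` are placeholders here, since their real values come from the platform query and the kernel's buffer layout outside the visible hunk.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed sample values, for illustration only.
  int64_t ub_size_limit = 192 * 1024;  // UB budget in bytes
  int64_t buffer_coefficient = 10;     // bytes of UB consumed per element

  // Same formula as the host code: reserve 1024 bytes, divide by the
  // per-element cost, then round down to a whole 32-byte block of fp32.
  int64_t max_dim_length = (ub_size_limit - 1024) / buffer_coefficient;  // 19558
  int64_t fp_align_elements = 32 / 4;                                    // 8
  max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements;

  std::printf("max_dim_length = %lld\n",
              static_cast<long long>(max_dim_length));  // prints 19552
  // An aligned row of 19552 elements or fewer passes the TORCH_CHECK;
  // anything larger raises the "exceeds UB capacity" error.
  return 0;
}
```
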

src/ascend/custom/rms_norm/op_kernel/rms_norm.cpp

Lines changed: 0 additions & 7 deletions
@@ -1,10 +1,3 @@
-/*
- * Copyright (c) 2025, InfiniTensor.
- * All rights reserved.
- *
- * SPDX-License-Identifier: BSD-3-Clause
- */
-
 #include "kernel_operator.h"
 
 constexpr int32_t BUFFER_NUM = 2;

src/ascend/custom/torch_binding.cpp

Lines changed: 1 addition & 1 deletion
@@ -26,6 +26,6 @@ TORCH_LIBRARY_FRAGMENT(npu, m) {
 }
 
 TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) {
-  m.impl("rms_norm", TORCH_FN(ascend::detail::rms_norm));
+  m.impl("rms_norm", TORCH_FN(ascend::detail::RmsNorm));
 }
 } // namespace
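
The last bullet of the commit message notes that only the C++ symbol changed; the Python-visible name stays `torch.ops.npu.rms_norm`. The sketch below shows that split in one place. The `m.def` schema string is assumed for illustration (the real definition sits above the hunk), and the forward declaration is repeated here only to keep the snippet self-contained.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

namespace ascend::detail {
at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
                   double eps);
}  // namespace ascend::detail

// Schema registration: the registry (and thus Python) name stays snake_case.
// The schema string is an assumption shown only for illustration.
TORCH_LIBRARY_FRAGMENT(npu, m) {
  m.def("rms_norm(Tensor input, Tensor weight, float eps) -> Tensor");
}

// Backend binding: only the C++ wrapper moved to PascalCase.
TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) {
  m.impl("rms_norm", TORCH_FN(ascend::detail::RmsNorm));
}
```
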

src/ascend/device_.h

Lines changed: 1 addition & 4 deletions
@@ -1,10 +1,7 @@
 #ifndef INFINI_OPS_ASCEND_DEVICE__H_
 #define INFINI_OPS_ASCEND_DEVICE__H_
 
-// NOTE: Cannot use `#include "device.h"` here — GCC resolves quoted includes
-// relative to the current file first, and `src/ascend/` used to contain a
-// `device.h`. Use `data_type.h` which transitively pulls in `src/device.h`.
-#include "data_type.h"
+#include "device.h"
 
 namespace infini::ops {
 