Skip to content

Commit d60c180

Browse files
author
zhangyue
committed
refactor(pr66): rename AscendC custom kernels to PascalCase + C2 param order
Addresses Ziminli's comment on `aclrtlaunch_add_rms_norm` forward-decl (#66 discussion 3115868675 / 3129096521): - **函数名格式 (function-name format):** the AscendC kernel entry-points `add_rms_norm` / `rms_norm` are renamed to `AddRmsNorm` / `RmsNorm`. The AscendC toolchain prepends `aclrtlaunch_` on the symbol regardless of case, so the exported names become `aclrtlaunch_AddRmsNorm` / `aclrtlaunch_RmsNorm` — matching the base-class names and `readability-identifier-naming.FunctionCase = CamelCase`. The `NOLINTNEXTLINE(readability-identifier-naming)` shim and the "PascalCase rule does not apply" apology comments go away. - **参数列表顺序 (parameter-list order, C2):** reorder parameters to `inputs, attributes, outputs`. Both `.cpp` kernel entry, `KernelAddRmsNorm::Init` / `KernelRmsNorm::Init`, and the `extern "C"` forward-decl in `kernel_custom.h` are updated together, along with the call sites in `operator()`. - **Variable naming (`.cpp` internals):** `x1/x2/y/x_out` → `input/residual/out/residual_out`; `x/y` → `input/out`. Cascaded through member names (`*_gm_`, `*_queue_*`, `*_local`) for consistency — internal to each kernel class, no ABI impact. - **`op_host/*.cpp`:** updated to include the PascalCase generated header `aclrtlaunch_AddRmsNorm.h` / `aclrtlaunch_RmsNorm.h` and to match the reordered `EXEC_KERNEL_CMD` argument list. Verified locally with `.ci/run.py --local`: test_add_rms_norm.py: 108 passed; test_rms_norm.py: 72 passed. The AscendC toolchain successfully compiles the PascalCase kernel entries and generates matching launch headers — the `aclrtlaunch_<ENTRY>` macro concatenates regardless of case.
1 parent 659ae35 commit d60c180

6 files changed

Lines changed: 174 additions & 171 deletions

File tree

src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -14,18 +14,16 @@
1414
#include "base/add_rms_norm.h"
1515
#include "operator.h"
1616

17-
// Forward-declare the generated AscendC kernel launch function.
18-
// This symbol is provided by the `no_workspace_kernel` static library
19-
// built from `ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp`
20-
// via `ascendc_library()`.
21-
// `aclrtlaunch_*` symbol name is generated by `ascendc_library()` /
22-
// `ascendc_add_operator()` and cannot be `PascalCase`d.
23-
// NOLINTNEXTLINE(readability-identifier-naming)
24-
extern "C" uint32_t aclrtlaunch_add_rms_norm(
17+
// Forward-declare the generated AscendC kernel launch function. This
18+
// symbol is provided by the `no_workspace_kernel` static library built
19+
// from `ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp` via
20+
// `ascendc_library()`; the `aclrtlaunch_` prefix is prepended by the
21+
// AscendC toolchain to the kernel entry's `extern "C"` name.
22+
extern "C" uint32_t aclrtlaunch_AddRmsNorm(
2523
uint32_t block_dim, void* stream, void* input, void* residual, void* weight,
26-
void* out, void* residual_out, int64_t total_rows, int64_t dim_length,
27-
int64_t dim_length_align, int64_t former_num, int64_t former_length,
28-
int64_t tail_length, float eps, int64_t dtype_code);
24+
int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
25+
int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
26+
int64_t dtype_code, void* out, void* residual_out);
2927

3028
namespace infini::ops {
3129

@@ -142,12 +140,12 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
142140
int64_t former_num = total_rows_ - tail_length * used_cores;
143141
uint32_t block_dim = static_cast<uint32_t>(used_cores);
144142

145-
aclrtlaunch_add_rms_norm(block_dim, stream, const_cast<void*>(input.data()),
146-
const_cast<void*>(residual.data()), weight_fp32,
147-
out.data(), residual_out.data(), total_rows_,
148-
static_cast<int64_t>(dim_), dim_length_align_,
149-
former_num, former_length, tail_length, eps,
150-
static_cast<int64_t>(dtype_));
143+
aclrtlaunch_AddRmsNorm(block_dim, stream, const_cast<void*>(input.data()),
144+
const_cast<void*>(residual.data()), weight_fp32,
145+
total_rows_, static_cast<int64_t>(dim_),
146+
dim_length_align_, former_num, former_length,
147+
tail_length, eps, static_cast<int64_t>(dtype_),
148+
out.data(), residual_out.data());
151149
}
152150

153151
private:

src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
#include "aclrtlaunch_add_rms_norm.h"
1+
#include "aclrtlaunch_AddRmsNorm.h"
22
#include "tiling/platform/platform_ascendc.h"
33
#include "torch_kernel_helper.h"
44

@@ -105,16 +105,14 @@ std::vector<at::Tensor> AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2,
105105
float eps_float = static_cast<float>(eps);
106106
int64_t dtype_size_val = dtype_size;
107107

108-
// The first arg `add_rms_norm` is the AscendC kernel entry-point name — it
109-
// must match `ascendc_add_operator(OP_NAME add_rms_norm)` in `CMakeLists.txt`,
110-
// the `__global__ __aicore__ void add_rms_norm(...)` definition in
111-
// `op_kernel/`, and the generated `aclrtlaunch_add_rms_norm.h` header.
112-
// Google C++ Style's PascalCase rule does NOT apply: this identifier is
113-
// dictated by the AscendC toolchain's symbol convention.
114-
EXEC_KERNEL_CMD(add_rms_norm, block_dim, kernel_input1, kernel_input2,
115-
weight_float, kernel_output_y, kernel_output_x_out,
116-
total_rows, dim_length, dim_length_align, former_num,
117-
former_length, tail_length, eps_float, dtype_size_val);
108+
// The first arg `AddRmsNorm` is the AscendC kernel entry-point name — it
109+
// must match the `__global__ __aicore__ void AddRmsNorm(...)` definition
110+
// in `op_kernel/` and the generated `aclrtlaunch_AddRmsNorm.h` header.
111+
// Parameter order follows the base class: inputs, attributes, outputs.
112+
EXEC_KERNEL_CMD(AddRmsNorm, block_dim, kernel_input1, kernel_input2,
113+
weight_float, total_rows, dim_length, dim_length_align,
114+
former_num, former_length, tail_length, eps_float,
115+
dtype_size_val, kernel_output_y, kernel_output_x_out);
118116

119117
// Remove padding and reshape back to original shape.
120118
at::Tensor output_y = kernel_output_y;

src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp

Lines changed: 80 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,11 @@ class KernelAddRmsNorm {
88
public:
99
__aicore__ inline KernelAddRmsNorm() {}
1010

11-
__aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y,
12-
GM_ADDR x_out, int64_t total_rows,
13-
int64_t dim_length, int64_t dim_length_align,
14-
int64_t former_num, int64_t former_length,
15-
int64_t tail_length, float eps) {
11+
__aicore__ inline void Init(GM_ADDR input, GM_ADDR residual, GM_ADDR weight,
12+
int64_t total_rows, int64_t dim_length,
13+
int64_t dim_length_align, int64_t former_num,
14+
int64_t former_length, int64_t tail_length,
15+
float eps, GM_ADDR out, GM_ADDR residual_out) {
1616
dim_length_ = dim_length;
1717
dim_length_align_ = dim_length_align;
1818
eps_ = eps;
@@ -31,26 +31,28 @@ class KernelAddRmsNorm {
3131
}
3232

3333
// Global memory pointers.
34-
x1_gm_.SetGlobalBuffer((__gm__ T*)x1 + row_offset * dim_length_align,
35-
block_rows_ * dim_length_align);
36-
x2_gm_.SetGlobalBuffer((__gm__ T*)x2 + row_offset * dim_length_align,
37-
block_rows_ * dim_length_align);
38-
y_gm_.SetGlobalBuffer((__gm__ T*)y + row_offset * dim_length_align,
39-
block_rows_ * dim_length_align);
40-
x_out_gm_.SetGlobalBuffer((__gm__ T*)x_out + row_offset * dim_length_align,
34+
input_gm_.SetGlobalBuffer((__gm__ T*)input + row_offset * dim_length_align,
4135
block_rows_ * dim_length_align);
36+
residual_gm_.SetGlobalBuffer(
37+
(__gm__ T*)residual + row_offset * dim_length_align,
38+
block_rows_ * dim_length_align);
39+
out_gm_.SetGlobalBuffer((__gm__ T*)out + row_offset * dim_length_align,
40+
block_rows_ * dim_length_align);
41+
residual_out_gm_.SetGlobalBuffer(
42+
(__gm__ T*)residual_out + row_offset * dim_length_align,
43+
block_rows_ * dim_length_align);
4244
weight_gm_.SetGlobalBuffer((__gm__ float*)weight, dim_length_align);
4345

4446
int32_t dim_len_align = static_cast<int32_t>(dim_length_align_);
4547

4648
// I/O queues (double-buffered).
47-
pipe_.InitBuffer(in_queue_x1_, kBufferNum,
49+
pipe_.InitBuffer(in_queue_input_, kBufferNum,
4850
dim_len_align * static_cast<int32_t>(sizeof(T)));
49-
pipe_.InitBuffer(in_queue_x2_, kBufferNum,
51+
pipe_.InitBuffer(in_queue_residual_, kBufferNum,
5052
dim_len_align * static_cast<int32_t>(sizeof(T)));
51-
pipe_.InitBuffer(out_queue_y_, kBufferNum,
53+
pipe_.InitBuffer(out_queue_out_, kBufferNum,
5254
dim_len_align * static_cast<int32_t>(sizeof(T)));
53-
pipe_.InitBuffer(out_queue_x_out_, kBufferNum,
55+
pipe_.InitBuffer(out_queue_residual_out_, kBufferNum,
5456
dim_len_align * static_cast<int32_t>(sizeof(T)));
5557

5658
// Weight buffer (fp32, loaded once, reused for all rows).
@@ -103,24 +105,26 @@ class KernelAddRmsNorm {
103105

104106
private:
105107
__aicore__ inline void CopyIn(int64_t row) {
106-
AscendC::LocalTensor<T> x1_local = in_queue_x1_.AllocTensor<T>();
107-
AscendC::LocalTensor<T> x2_local = in_queue_x2_.AllocTensor<T>();
108+
AscendC::LocalTensor<T> input_local = in_queue_input_.AllocTensor<T>();
109+
AscendC::LocalTensor<T> residual_local =
110+
in_queue_residual_.AllocTensor<T>();
108111
AscendC::DataCopyExtParams params{
109112
1, static_cast<uint32_t>(dim_length_align_ * sizeof(T)), 0, 0, 0};
110113
AscendC::DataCopyPadExtParams<T> pad{false, 0, 0, static_cast<T>(0)};
111-
AscendC::DataCopyPad(x1_local, x1_gm_[row * dim_length_align_], params,
112-
pad);
113-
AscendC::DataCopyPad(x2_local, x2_gm_[row * dim_length_align_], params,
114-
pad);
115-
in_queue_x1_.EnQue(x1_local);
116-
in_queue_x2_.EnQue(x2_local);
114+
AscendC::DataCopyPad(input_local, input_gm_[row * dim_length_align_],
115+
params, pad);
116+
AscendC::DataCopyPad(residual_local, residual_gm_[row * dim_length_align_],
117+
params, pad);
118+
in_queue_input_.EnQue(input_local);
119+
in_queue_residual_.EnQue(residual_local);
117120
}
118121

119122
__aicore__ inline void Compute(int64_t row) {
120-
AscendC::LocalTensor<T> x1_local = in_queue_x1_.DeQue<T>();
121-
AscendC::LocalTensor<T> x2_local = in_queue_x2_.DeQue<T>();
122-
AscendC::LocalTensor<T> y_local = out_queue_y_.AllocTensor<T>();
123-
AscendC::LocalTensor<T> x_out_local = out_queue_x_out_.AllocTensor<T>();
123+
AscendC::LocalTensor<T> input_local = in_queue_input_.DeQue<T>();
124+
AscendC::LocalTensor<T> residual_local = in_queue_residual_.DeQue<T>();
125+
AscendC::LocalTensor<T> out_local = out_queue_out_.AllocTensor<T>();
126+
AscendC::LocalTensor<T> residual_out_local =
127+
out_queue_residual_out_.AllocTensor<T>();
124128

125129
AscendC::LocalTensor<float> w_local = weight_buf_.Get<float>();
126130
AscendC::LocalTensor<float> r_tmp = reduce_tmp_buf_.Get<float>();
@@ -133,14 +137,16 @@ class KernelAddRmsNorm {
133137
// ---- FP32 path: compute directly. ----
134138

135139
// Step 1: x_out = x1 + x2.
136-
AscendC::Add(x_out_local, x1_local, x2_local, dim_len_align);
140+
AscendC::Add(residual_out_local, input_local, residual_local,
141+
dim_len_align);
137142

138-
// Step 2: x_out^2 into y_local (reuse output buffer temporarily).
139-
AscendC::Mul(y_local, x_out_local, x_out_local, dim_len_align);
143+
// Step 2: x_out^2 into out_local (reuse output buffer temporarily).
144+
AscendC::Mul(out_local, residual_out_local, residual_out_local,
145+
dim_len_align);
140146

141147
// Step 3: ReduceSum(x_out^2) -> s_local[0].
142-
// `ReduceSum` may modify `y_local`, but we overwrite it below.
143-
AscendC::ReduceSum(s_local, y_local, r_tmp, dim_len_align);
148+
// `ReduceSum` may modify `out_local`, but we overwrite it below.
149+
AscendC::ReduceSum(s_local, out_local, r_tmp, dim_len_align);
144150

145151
// Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps).
146152
float sum_val = s_local.GetValue(0);
@@ -150,25 +156,27 @@ class KernelAddRmsNorm {
150156
float scale = 1.0f / s_local.GetValue(0);
151157

152158
// Step 6: y = x_out * scale.
153-
AscendC::Muls(y_local, x_out_local, scale, dim_len_align);
159+
AscendC::Muls(out_local, residual_out_local, scale, dim_len_align);
154160

155161
// Step 7: y = y * weight.
156-
AscendC::Mul(y_local, y_local, w_local, dim_len_align);
162+
AscendC::Mul(out_local, out_local, w_local, dim_len_align);
157163

158164
} else {
159165
// ---- FP16/BF16 path: cast → fp32 compute → cast back. ----
160166
AscendC::LocalTensor<float> b1 = fp32_buf1_.Get<float>();
161167
AscendC::LocalTensor<float> b2 = fp32_buf2_.Get<float>();
162168

163169
// Cast inputs fp16/bf16 → fp32.
164-
AscendC::Cast(b1, x1_local, AscendC::RoundMode::CAST_NONE, dim_len_align);
165-
AscendC::Cast(b2, x2_local, AscendC::RoundMode::CAST_NONE, dim_len_align);
170+
AscendC::Cast(b1, input_local, AscendC::RoundMode::CAST_NONE,
171+
dim_len_align);
172+
AscendC::Cast(b2, residual_local, AscendC::RoundMode::CAST_NONE,
173+
dim_len_align);
166174

167175
// Step 1: x_out = x1 + x2 (fp32), stored in b1.
168176
AscendC::Add(b1, b1, b2, dim_len_align);
169177

170178
// Cast `x_out` fp32 → fp16/bf16 for the residual output.
171-
AscendC::Cast(x_out_local, b1, AscendC::RoundMode::CAST_RINT,
179+
AscendC::Cast(residual_out_local, b1, AscendC::RoundMode::CAST_RINT,
172180
dim_len_align);
173181

174182
// Step 2: x_out^2 in fp32, stored in b2.
@@ -190,41 +198,43 @@ class KernelAddRmsNorm {
190198
// Step 7: y = y * weight (fp32).
191199
AscendC::Mul(b2, b2, w_local, dim_len_align);
192200

193-
AscendC::Cast(y_local, b2, AscendC::RoundMode::CAST_RINT, dim_len_align);
201+
AscendC::Cast(out_local, b2, AscendC::RoundMode::CAST_RINT,
202+
dim_len_align);
194203
}
195204

196-
in_queue_x1_.FreeTensor(x1_local);
197-
in_queue_x2_.FreeTensor(x2_local);
198-
out_queue_y_.EnQue(y_local);
199-
out_queue_x_out_.EnQue(x_out_local);
205+
in_queue_input_.FreeTensor(input_local);
206+
in_queue_residual_.FreeTensor(residual_local);
207+
out_queue_out_.EnQue(out_local);
208+
out_queue_residual_out_.EnQue(residual_out_local);
200209
}
201210

202211
__aicore__ inline void CopyOut(int64_t row) {
203-
AscendC::LocalTensor<T> y_local = out_queue_y_.DeQue<T>();
204-
AscendC::LocalTensor<T> x_out_local = out_queue_x_out_.DeQue<T>();
212+
AscendC::LocalTensor<T> out_local = out_queue_out_.DeQue<T>();
213+
AscendC::LocalTensor<T> residual_out_local =
214+
out_queue_residual_out_.DeQue<T>();
205215
AscendC::DataCopyExtParams params{
206216
1, static_cast<uint32_t>(dim_length_align_ * sizeof(T)), 0, 0, 0};
207-
AscendC::DataCopyPad(y_gm_[row * dim_length_align_], y_local, params);
208-
AscendC::DataCopyPad(x_out_gm_[row * dim_length_align_], x_out_local,
209-
params);
210-
out_queue_y_.FreeTensor(y_local);
211-
out_queue_x_out_.FreeTensor(x_out_local);
217+
AscendC::DataCopyPad(out_gm_[row * dim_length_align_], out_local, params);
218+
AscendC::DataCopyPad(residual_out_gm_[row * dim_length_align_],
219+
residual_out_local, params);
220+
out_queue_out_.FreeTensor(out_local);
221+
out_queue_residual_out_.FreeTensor(residual_out_local);
212222
}
213223

214224
private:
215225
AscendC::TPipe pipe_;
216-
AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum> in_queue_x1_;
217-
AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum> in_queue_x2_;
218-
AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum> out_queue_y_;
219-
AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum> out_queue_x_out_;
226+
AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum> in_queue_input_;
227+
AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum> in_queue_residual_;
228+
AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum> out_queue_out_;
229+
AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum> out_queue_residual_out_;
220230

221231
AscendC::TBuf<AscendC::TPosition::VECCALC> weight_buf_;
222232
AscendC::TBuf<AscendC::TPosition::VECCALC> fp32_buf1_;
223233
AscendC::TBuf<AscendC::TPosition::VECCALC> fp32_buf2_;
224234
AscendC::TBuf<AscendC::TPosition::VECCALC> reduce_tmp_buf_;
225235
AscendC::TBuf<AscendC::TPosition::VECCALC> sum_buf_;
226236

227-
AscendC::GlobalTensor<T> x1_gm_, x2_gm_, y_gm_, x_out_gm_;
237+
AscendC::GlobalTensor<T> input_gm_, residual_gm_, out_gm_, residual_out_gm_;
228238
AscendC::GlobalTensor<float> weight_gm_;
229239

230240
int64_t block_rows_;
@@ -238,34 +248,35 @@ class KernelAddRmsNorm {
238248
// distinct numeric paths, so dispatch is on the `DataType` tag rather
239249
// than the byte size.
240250
//
241-
// The symbol name `add_rms_norm` must match the `OP_NAME` passed to
242-
// `ascendc_add_operator()` / the `aclrtlaunch_*` header; Google C++
243-
// Style's PascalCase rule does not apply here (see `op_host/`).
244-
extern "C" __global__ __aicore__ void add_rms_norm(
245-
GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, GM_ADDR x_out,
246-
int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
247-
int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
248-
int64_t dtype_code) {
251+
// Parameters follow the C2 convention: inputs first, attributes between,
252+
// outputs last. The kernel symbol is prefixed with `aclrtlaunch_` by the
253+
// `AscendC` toolchain, yielding `aclrtlaunch_AddRmsNorm` which matches the
254+
// base `AddRmsNorm` class name.
255+
extern "C" __global__ __aicore__ void AddRmsNorm(
256+
GM_ADDR input, GM_ADDR residual, GM_ADDR weight, int64_t total_rows,
257+
int64_t dim_length, int64_t dim_length_align, int64_t former_num,
258+
int64_t former_length, int64_t tail_length, float eps, int64_t dtype_code,
259+
GM_ADDR out, GM_ADDR residual_out) {
249260
switch (static_cast<infini::ops::DataType>(dtype_code)) {
250261
case infini::ops::DataType::kFloat16: {
251262
KernelAddRmsNorm<half> op;
252-
op.Init(x1, x2, weight, y, x_out, total_rows, dim_length,
253-
dim_length_align, former_num, former_length, tail_length, eps);
263+
op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
264+
former_num, former_length, tail_length, eps, out, residual_out);
254265
op.Process();
255266
break;
256267
}
257268
case infini::ops::DataType::kBFloat16: {
258269
KernelAddRmsNorm<bfloat16_t> op;
259-
op.Init(x1, x2, weight, y, x_out, total_rows, dim_length,
260-
dim_length_align, former_num, former_length, tail_length, eps);
270+
op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
271+
former_num, former_length, tail_length, eps, out, residual_out);
261272
op.Process();
262273
break;
263274
}
264275
case infini::ops::DataType::kFloat32:
265276
default: {
266277
KernelAddRmsNorm<float> op;
267-
op.Init(x1, x2, weight, y, x_out, total_rows, dim_length,
268-
dim_length_align, former_num, former_length, tail_length, eps);
278+
op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
279+
former_num, former_length, tail_length, eps, out, residual_out);
269280
op.Process();
270281
break;
271282
}

src/ascend/custom/rms_norm/op_host/rms_norm.cpp

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
#include "aclrtlaunch_rms_norm.h"
1+
#include "aclrtlaunch_RmsNorm.h"
22
#include "tiling/platform/platform_ascendc.h"
33
#include "torch_kernel_helper.h"
44

@@ -94,16 +94,13 @@ at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
9494
float eps_float = static_cast<float>(eps);
9595
int64_t dtype_size_val = dtype_size;
9696

97-
// The first arg `rms_norm` is the AscendC kernel entry-point name — it
98-
// must match `ascendc_add_operator(OP_NAME rms_norm)` in `CMakeLists.txt`,
99-
// the `__global__ __aicore__ void rms_norm(...)` definition in `op_kernel/`,
100-
// and the generated `aclrtlaunch_rms_norm.h` header. Google C++ Style's
101-
// PascalCase rule does NOT apply: this identifier is dictated by the
102-
// AscendC toolchain's symbol convention.
103-
EXEC_KERNEL_CMD(rms_norm, block_dim, kernel_input, weight_float,
104-
kernel_output, total_rows, dim_length, dim_length_align,
105-
former_num, former_length, tail_length, eps_float,
106-
dtype_size_val);
97+
// The first arg `RmsNorm` is the AscendC kernel entry-point name — it
98+
// must match the `__global__ __aicore__ void RmsNorm(...)` definition in
99+
// `op_kernel/` and the generated `aclrtlaunch_RmsNorm.h` header.
100+
// Parameter order follows the base class: inputs, attributes, outputs.
101+
EXEC_KERNEL_CMD(RmsNorm, block_dim, kernel_input, weight_float, total_rows,
102+
dim_length, dim_length_align, former_num, former_length,
103+
tail_length, eps_float, dtype_size_val, kernel_output);
107104

108105
// Remove padding and reshape back to original shape.
109106
at::Tensor output = kernel_output;

0 commit comments

Comments (0)