@@ -8,11 +8,11 @@ class KernelAddRmsNorm {
88 public:
99 __aicore__ inline KernelAddRmsNorm () {}
1010
11- __aicore__ inline void Init (GM_ADDR x1 , GM_ADDR x2 , GM_ADDR weight, GM_ADDR y ,
12- GM_ADDR x_out , int64_t total_rows ,
13- int64_t dim_length , int64_t dim_length_align ,
14- int64_t former_num , int64_t former_length ,
15- int64_t tail_length, float eps ) {
11+ __aicore__ inline void Init (GM_ADDR input , GM_ADDR residual , GM_ADDR weight,
12+ int64_t total_rows , int64_t dim_length ,
13+ int64_t dim_length_align , int64_t former_num ,
14+ int64_t former_length , int64_t tail_length ,
15+ float eps, GM_ADDR out, GM_ADDR residual_out ) {
1616 dim_length_ = dim_length;
1717 dim_length_align_ = dim_length_align;
1818 eps_ = eps;
@@ -31,26 +31,28 @@ class KernelAddRmsNorm {
3131 }
3232
3333 // Global memory pointers.
34- x1_gm_.SetGlobalBuffer ((__gm__ T*)x1 + row_offset * dim_length_align,
35- block_rows_ * dim_length_align);
36- x2_gm_.SetGlobalBuffer ((__gm__ T*)x2 + row_offset * dim_length_align,
37- block_rows_ * dim_length_align);
38- y_gm_.SetGlobalBuffer ((__gm__ T*)y + row_offset * dim_length_align,
39- block_rows_ * dim_length_align);
40- x_out_gm_.SetGlobalBuffer ((__gm__ T*)x_out + row_offset * dim_length_align,
34+ input_gm_.SetGlobalBuffer ((__gm__ T*)input + row_offset * dim_length_align,
4135 block_rows_ * dim_length_align);
36+ residual_gm_.SetGlobalBuffer (
37+ (__gm__ T*)residual + row_offset * dim_length_align,
38+ block_rows_ * dim_length_align);
39+ out_gm_.SetGlobalBuffer ((__gm__ T*)out + row_offset * dim_length_align,
40+ block_rows_ * dim_length_align);
41+ residual_out_gm_.SetGlobalBuffer (
42+ (__gm__ T*)residual_out + row_offset * dim_length_align,
43+ block_rows_ * dim_length_align);
4244 weight_gm_.SetGlobalBuffer ((__gm__ float *)weight, dim_length_align);
4345
4446 int32_t dim_len_align = static_cast <int32_t >(dim_length_align_);
4547
4648 // I/O queues (double-buffered).
47- pipe_.InitBuffer (in_queue_x1_ , kBufferNum ,
49+ pipe_.InitBuffer (in_queue_input_ , kBufferNum ,
4850 dim_len_align * static_cast <int32_t >(sizeof (T)));
49- pipe_.InitBuffer (in_queue_x2_ , kBufferNum ,
51+ pipe_.InitBuffer (in_queue_residual_ , kBufferNum ,
5052 dim_len_align * static_cast <int32_t >(sizeof (T)));
51- pipe_.InitBuffer (out_queue_y_ , kBufferNum ,
53+ pipe_.InitBuffer (out_queue_out_ , kBufferNum ,
5254 dim_len_align * static_cast <int32_t >(sizeof (T)));
53- pipe_.InitBuffer (out_queue_x_out_ , kBufferNum ,
55+ pipe_.InitBuffer (out_queue_residual_out_ , kBufferNum ,
5456 dim_len_align * static_cast <int32_t >(sizeof (T)));
5557
5658 // Weight buffer (fp32, loaded once, reused for all rows).
@@ -103,24 +105,26 @@ class KernelAddRmsNorm {
103105
104106 private:
// Stages one row of both operands from global memory into the two input
// queues. `row` is the row index relative to this block's global buffers; the
// source offset is `row * dim_length_align_` elements. In this diff the
// operands are renamed from x1/x2 to input/residual; the data flow is the same
// on both sides of the change.
105107 __aicore__ inline void CopyIn (int64_t row) {
106- AscendC::LocalTensor<T> x1_local = in_queue_x1_.AllocTensor <T>();
107- AscendC::LocalTensor<T> x2_local = in_queue_x2_.AllocTensor <T>();
108+ AscendC::LocalTensor<T> input_local = in_queue_input_.AllocTensor <T>();
109+ AscendC::LocalTensor<T> residual_local =
110+ in_queue_residual_.AllocTensor <T>();
// One transfer descriptor is shared by both copies. Its second field is the
// byte length of a full aligned row (`dim_length_align_ * sizeof(T)`), so the
// alignment tail of the row is copied too.
// NOTE(review): field order presumed to be (blockCount, blockLen, srcStride,
// dstStride, rsv) — confirm against the AscendC SDK version in use.
108111 AscendC::DataCopyExtParams params{
109112 1 , static_cast <uint32_t >(dim_length_align_ * sizeof (T)), 0 , 0 , 0 };
// NOTE(review): the leading `false` presumably disables padding, which would
// leave the remaining pad fields unused — confirm against the SDK.
110113 AscendC::DataCopyPadExtParams<T> pad{false , 0 , 0 , static_cast <T>(0 )};
111- AscendC::DataCopyPad (x1_local, x1_gm_ [row * dim_length_align_], params ,
112- pad);
113- AscendC::DataCopyPad (x2_local, x2_gm_ [row * dim_length_align_], params ,
114- pad);
115- in_queue_x1_ .EnQue (x1_local );
116- in_queue_x2_ .EnQue (x2_local );
114+ AscendC::DataCopyPad (input_local, input_gm_ [row * dim_length_align_],
115+ params, pad);
116+ AscendC::DataCopyPad (residual_local, residual_gm_ [row * dim_length_align_],
117+ params, pad);
118+ in_queue_input_ .EnQue (input_local );
119+ in_queue_residual_ .EnQue (residual_local );
117120 }
118121
119122 __aicore__ inline void Compute (int64_t row) {
120- AscendC::LocalTensor<T> x1_local = in_queue_x1_.DeQue <T>();
121- AscendC::LocalTensor<T> x2_local = in_queue_x2_.DeQue <T>();
122- AscendC::LocalTensor<T> y_local = out_queue_y_.AllocTensor <T>();
123- AscendC::LocalTensor<T> x_out_local = out_queue_x_out_.AllocTensor <T>();
123+ AscendC::LocalTensor<T> input_local = in_queue_input_.DeQue <T>();
124+ AscendC::LocalTensor<T> residual_local = in_queue_residual_.DeQue <T>();
125+ AscendC::LocalTensor<T> out_local = out_queue_out_.AllocTensor <T>();
126+ AscendC::LocalTensor<T> residual_out_local =
127+ out_queue_residual_out_.AllocTensor <T>();
124128
125129 AscendC::LocalTensor<float > w_local = weight_buf_.Get <float >();
126130 AscendC::LocalTensor<float > r_tmp = reduce_tmp_buf_.Get <float >();
@@ -133,14 +137,16 @@ class KernelAddRmsNorm {
133137 // ---- FP32 path: compute directly. ----
134138
135139 // Step 1: residual_out = input + residual.
136- AscendC::Add (x_out_local, x1_local, x2_local, dim_len_align);
140+ AscendC::Add (residual_out_local, input_local, residual_local,
141+ dim_len_align);
137142
138- // Step 2: x_out^2 into y_local (reuse output buffer temporarily).
139- AscendC::Mul (y_local, x_out_local, x_out_local, dim_len_align);
143+ // Step 2: residual_out^2 into out_local (reuse output buffer temporarily).
144+ AscendC::Mul (out_local, residual_out_local, residual_out_local,
145+ dim_len_align);
140146
141147 // Step 3: ReduceSum(residual_out^2) -> s_local[0].
142- // `ReduceSum` may modify `y_local `, but we overwrite it below.
143- AscendC::ReduceSum (s_local, y_local , r_tmp, dim_len_align);
148+ // `ReduceSum` may modify `out_local`, but we overwrite it below.
149+ AscendC::ReduceSum (s_local, out_local , r_tmp, dim_len_align);
144150
145151 // Step 4-5: scale = 1 / sqrt(mean(residual_out^2) + eps).
146152 float sum_val = s_local.GetValue (0 );
@@ -150,25 +156,27 @@ class KernelAddRmsNorm {
150156 float scale = 1 .0f / s_local.GetValue (0 );
151157
152158 // Step 6: out = residual_out * scale.
153- AscendC::Muls (y_local, x_out_local , scale, dim_len_align);
159+ AscendC::Muls (out_local, residual_out_local , scale, dim_len_align);
154160
155161 // Step 7: out = out * weight.
156- AscendC::Mul (y_local, y_local , w_local, dim_len_align);
162+ AscendC::Mul (out_local, out_local , w_local, dim_len_align);
157163
158164 } else {
159165 // ---- FP16/BF16 path: cast → fp32 compute → cast back. ----
160166 AscendC::LocalTensor<float > b1 = fp32_buf1_.Get <float >();
161167 AscendC::LocalTensor<float > b2 = fp32_buf2_.Get <float >();
162168
163169 // Cast inputs fp16/bf16 → fp32.
164- AscendC::Cast (b1, x1_local, AscendC::RoundMode::CAST_NONE, dim_len_align);
165- AscendC::Cast (b2, x2_local, AscendC::RoundMode::CAST_NONE, dim_len_align);
170+ AscendC::Cast (b1, input_local, AscendC::RoundMode::CAST_NONE,
171+ dim_len_align);
172+ AscendC::Cast (b2, residual_local, AscendC::RoundMode::CAST_NONE,
173+ dim_len_align);
166174
167175 // Step 1: residual_out = input + residual (fp32), stored in b1.
168176 AscendC::Add (b1, b1, b2, dim_len_align);
169177
170178 // Cast the fp32 sum in b1 → fp16/bf16 into `residual_out_local` for the residual output.
171- AscendC::Cast (x_out_local , b1, AscendC::RoundMode::CAST_RINT,
179+ AscendC::Cast (residual_out_local , b1, AscendC::RoundMode::CAST_RINT,
172180 dim_len_align);
173181
174182 // Step 2: residual_out^2 in fp32, stored in b2.
@@ -190,41 +198,43 @@ class KernelAddRmsNorm {
190198 // Step 7: out = out * weight (fp32), computed in b2.
191199 AscendC::Mul (b2, b2, w_local, dim_len_align);
192200
193- AscendC::Cast (y_local, b2, AscendC::RoundMode::CAST_RINT, dim_len_align);
201+ AscendC::Cast (out_local, b2, AscendC::RoundMode::CAST_RINT,
202+ dim_len_align);
194203 }
195204
196- in_queue_x1_ .FreeTensor (x1_local );
197- in_queue_x2_ .FreeTensor (x2_local );
198- out_queue_y_ .EnQue (y_local );
199- out_queue_x_out_ .EnQue (x_out_local );
205+ in_queue_input_ .FreeTensor (input_local );
206+ in_queue_residual_ .FreeTensor (residual_local );
207+ out_queue_out_ .EnQue (out_local );
208+ out_queue_residual_out_ .EnQue (residual_out_local );
200209 }
201210
// Drains one row from each of the two output queues and writes it back to
// global memory at element offset `row * dim_length_align_`, then returns the
// local tensors to their queues. In this diff the outputs are renamed from
// y/x_out to out/residual_out; the data flow is the same on both sides of the
// change.
202211 __aicore__ inline void CopyOut (int64_t row) {
203- AscendC::LocalTensor<T> y_local = out_queue_y_.DeQue <T>();
204- AscendC::LocalTensor<T> x_out_local = out_queue_x_out_.DeQue <T>();
212+ AscendC::LocalTensor<T> out_local = out_queue_out_.DeQue <T>();
213+ AscendC::LocalTensor<T> residual_out_local =
214+ out_queue_residual_out_.DeQue <T>();
// Same descriptor values as in `CopyIn`: the second field is the byte length
// of one full aligned row, so the alignment tail is written back as well.
// NOTE(review): field order presumed to be (blockCount, blockLen, srcStride,
// dstStride, rsv) — confirm against the AscendC SDK version in use.
205215 AscendC::DataCopyExtParams params{
206216 1 , static_cast <uint32_t >(dim_length_align_ * sizeof (T)), 0 , 0 , 0 };
207- AscendC::DataCopyPad (y_gm_ [row * dim_length_align_], y_local , params);
208- AscendC::DataCopyPad (x_out_gm_ [row * dim_length_align_], x_out_local ,
209- params);
210- out_queue_y_ .FreeTensor (y_local );
211- out_queue_x_out_ .FreeTensor (x_out_local );
217+ AscendC::DataCopyPad (out_gm_ [row * dim_length_align_], out_local , params);
218+ AscendC::DataCopyPad (residual_out_gm_ [row * dim_length_align_],
219+ residual_out_local, params);
220+ out_queue_out_ .FreeTensor (out_local );
221+ out_queue_residual_out_ .FreeTensor (residual_out_local );
212222 }
213223
214224 private:
215225 AscendC::TPipe pipe_;
216- AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum > in_queue_x1_ ;
217- AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum > in_queue_x2_ ;
218- AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum > out_queue_y_ ;
219- AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum > out_queue_x_out_ ;
226+ AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum > in_queue_input_ ;
227+ AscendC::TQue<AscendC::TPosition::VECIN, kBufferNum > in_queue_residual_ ;
228+ AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum > out_queue_out_ ;
229+ AscendC::TQue<AscendC::TPosition::VECOUT, kBufferNum > out_queue_residual_out_ ;
220230
221231 AscendC::TBuf<AscendC::TPosition::VECCALC> weight_buf_;
222232 AscendC::TBuf<AscendC::TPosition::VECCALC> fp32_buf1_;
223233 AscendC::TBuf<AscendC::TPosition::VECCALC> fp32_buf2_;
224234 AscendC::TBuf<AscendC::TPosition::VECCALC> reduce_tmp_buf_;
225235 AscendC::TBuf<AscendC::TPosition::VECCALC> sum_buf_;
226236
227- AscendC::GlobalTensor<T> x1_gm_, x2_gm_, y_gm_, x_out_gm_ ;
237+ AscendC::GlobalTensor<T> input_gm_, residual_gm_, out_gm_, residual_out_gm_ ;
228238 AscendC::GlobalTensor<float > weight_gm_;
229239
230240 int64_t block_rows_;
@@ -238,34 +248,35 @@ class KernelAddRmsNorm {
238248// distinct numeric paths, so dispatch is on the `DataType` tag rather
239249// than the byte size.
240250//
241- // The symbol name `add_rms_norm` must match the `OP_NAME` passed to
242- // `ascendc_add_operator()` / the `aclrtlaunch_*` header; Google C++
243- // Style's PascalCase rule does not apply here (see `op_host/`).
244- extern " C" __global__ __aicore__ void add_rms_norm (
245- GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, GM_ADDR x_out,
246- int64_t total_rows, int64_t dim_length, int64_t dim_length_align,
247- int64_t former_num, int64_t former_length, int64_t tail_length, float eps,
248- int64_t dtype_code) {
251+ // Parameters follow the C2 convention: inputs first, attributes between,
252+ // outputs last. The kernel symbol is prefixed with `aclrtlaunch_` by the
253+ // `AscendC` toolchain, yielding `aclrtlaunch_AddRmsNorm` which matches the
254+ // base `AddRmsNorm` class name.
255+ extern " C" __global__ __aicore__ void AddRmsNorm (
256+ GM_ADDR input, GM_ADDR residual, GM_ADDR weight, int64_t total_rows,
257+ int64_t dim_length, int64_t dim_length_align, int64_t former_num,
258+ int64_t former_length, int64_t tail_length, float eps, int64_t dtype_code,
259+ GM_ADDR out, GM_ADDR residual_out) {
249260 switch (static_cast <infini::ops::DataType>(dtype_code)) {
250261 case infini::ops::DataType::kFloat16 : {
251262 KernelAddRmsNorm<half> op;
252- op.Init (x1, x2 , weight, y, x_out, total_rows, dim_length,
253- dim_length_align, former_num, former_length, tail_length, eps);
263+ op.Init (input, residual , weight, total_rows, dim_length, dim_length_align ,
264+ former_num, former_length, tail_length, eps, out, residual_out );
254265 op.Process ();
255266 break ;
256267 }
257268 case infini::ops::DataType::kBFloat16 : {
258269 KernelAddRmsNorm<bfloat16_t > op;
259- op.Init (x1, x2 , weight, y, x_out, total_rows, dim_length,
260- dim_length_align, former_num, former_length, tail_length, eps);
270+ op.Init (input, residual , weight, total_rows, dim_length, dim_length_align ,
271+ former_num, former_length, tail_length, eps, out, residual_out );
261272 op.Process ();
262273 break ;
263274 }
264275 case infini::ops::DataType::kFloat32 :
265276 default : {
266277 KernelAddRmsNorm<float > op;
267- op.Init (x1, x2 , weight, y, x_out, total_rows, dim_length,
268- dim_length_align, former_num, former_length, tail_length, eps);
278+ op.Init (input, residual , weight, total_rows, dim_length, dim_length_align ,
279+ former_num, former_length, tail_length, eps, out, residual_out );
269280 op.Process ();
270281 break ;
271282 }
0 commit comments