#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/add_rms_norm/registry.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add_rms_norm.h"
#include "operator.h"
1818
// Forward-declare the generated AscendC kernel launch function.
// This symbol is provided by the `no_workspace_kernel` static library
// built from
// `ascend/custom_kernel/csrc/ops/add_rms_norm/op_kernel/add_rms_norm.cpp` via
// `ascendc_library()`.
// Launch wrapper for the custom AddRmsNorm AscendC kernel.
// NOTE(review): reconstructed from a diff-rendered paste — the duplicated
// '-'/'+' parameter lists and the stray space in the `extern "C"` linkage
// string were extraction artifacts, not code. Parameter meanings below are
// taken from the names only; confirm against the kernel source.
//   blockDim             - number of AI cores to launch on
//   stream               - ACL runtime stream handle
//   x1, x2, weight       - device pointers to the two addends and the gamma
//   y, x_out             - device pointers for normalized output and x1+x2
//   totalRows            - rows to normalize (batch * heads, presumably)
//   dimLength/..Align    - last-dim length and its 32-byte aligned length
//   formerNum/..Length,
//   tailLength           - per-core row tiling split
//   eps                  - RMSNorm epsilon
//   dtypeSize            - element size in bytes of the input dtype
extern "C" uint32_t aclrtlaunch_add_rms_norm(
    uint32_t blockDim, void* stream, void* x1, void* x2, void* weight, void* y,
    void* x_out, int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
    int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps,
    int64_t dtypeSize);
2929
3030namespace infini ::ops {
3131
@@ -62,8 +62,8 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
6262 assert (static_cast <int64_t >(dim_) == dim_length_align_ &&
6363 " Custom AddRmsNorm kernel requires 32-byte aligned last dimension" );
6464
65- total_rows_ = static_cast < int64_t >(batch_size_) *
66- static_cast <int64_t >(nhead_);
65+ total_rows_ =
66+ static_cast < int64_t >(batch_size_) * static_cast <int64_t >(nhead_);
6767
6868 // For fp16 input, weight needs fp32 conversion because the custom
6969 // kernel always reads weight as fp32.
@@ -72,16 +72,15 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
7272 if (needs_weight_cast_) {
7373 // Allocate persistent fp32 weight buffer on device.
7474 size_t fp32_bytes = static_cast <size_t >(dim_) * sizeof (float );
75- aclrtMalloc (&weight_fp32_data_, fp32_bytes,
76- ACL_MEM_MALLOC_NORMAL_ONLY);
75+ aclrtMalloc (&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);
7776
7877 // AclTensorCache for the cast source (fp16 weight descriptor).
79- weight_src_cache_ = ascend::AclTensorCache (
80- { static_cast < int64_t >(dim_)}, ACL_FLOAT16, nullptr );
78+ weight_src_cache_ = ascend::AclTensorCache ({ static_cast < int64_t >(dim_)},
79+ ACL_FLOAT16, nullptr );
8180
8281 // AclTensorCache for the cast destination (fp32 weight buffer).
83- weight_dst_cache_ = ascend::AclTensorCache (
84- { static_cast < int64_t >(dim_)}, ACL_FLOAT, weight_fp32_data_);
82+ weight_dst_cache_ = ascend::AclTensorCache ({ static_cast < int64_t >(dim_)},
83+ ACL_FLOAT, weight_fp32_data_);
8584 }
8685 }
8786
@@ -105,8 +104,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
105104 const void * cur_weight = gamma.data ();
106105
107106 if (cur_weight != last_weight_ptr_) {
108- auto t_src =
109- weight_src_cache_.get (const_cast <void *>(cur_weight));
107+ auto t_src = weight_src_cache_.get (const_cast <void *>(cur_weight));
110108 auto t_dst = weight_dst_cache_.get (weight_fp32_data_);
111109
112110 if (!cast_exec_) {
@@ -133,25 +131,17 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
133131 // Block-level tiling: distribute rows across cores.
134132 static constexpr int64_t kMaxBlockDim = 40 ;
135133 int64_t used_cores = std::min (total_rows_, kMaxBlockDim );
136- int64_t former_length =
137- (total_rows_ + used_cores - 1 ) / used_cores;
134+ int64_t former_length = (total_rows_ + used_cores - 1 ) / used_cores;
138135 int64_t tail_length = former_length - 1 ;
139136 int64_t former_num = total_rows_ - tail_length * used_cores;
140137 uint32_t block_dim = static_cast <uint32_t >(used_cores);
141138
142139 // Launch custom AscendC kernel.
143140 aclrtlaunch_add_rms_norm (
144- block_dim, stream,
145- const_cast <void *>(x1.data ()),
146- const_cast <void *>(x2.data ()),
147- weight_fp32,
148- y_out.data (),
149- x_out.data (),
150- total_rows_,
151- static_cast <int64_t >(dim_),
152- dim_length_align_,
153- former_num, former_length, tail_length,
154- eps, dtype_size_);
141+ block_dim, stream, const_cast <void *>(x1.data ()),
142+ const_cast <void *>(x2.data ()), weight_fp32, y_out.data (), x_out.data (),
143+ total_rows_, static_cast <int64_t >(dim_), dim_length_align_, former_num,
144+ former_length, tail_length, eps, dtype_size_);
155145 }
156146
157147 private:
0 commit comments