
Commit e38d08b

zhangyue committed
feat(ascend): op-norm-rope group — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding
Seven layer-level Ascend operators:

| op | impl |
|---|---|
| Swiglu | aclnnSilu + aclnnMul (decomposed); `kernel_fused.h` wraps fused swiglu where available |
| SiluAndMul | custom AscendC kernel |
| CausalSoftmax | aclnnSoftmax + pre-computed mask |
| RmsNorm | aclnnRmsNorm (kernel.h); custom AscendC variant (kernel_custom.h) |
| AddRmsNorm | 3 impls: decomposed aclnnAdd + aclnnRmsNorm (kernel.h); fused aclnnAddRmsNorm (kernel_fused.h); custom AscendC (kernel_custom.h) |
| ApplyRotaryPosEmb | aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam (kernel_atb.h) |
| RotaryEmbedding | **3 impls**: aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam with both neox/interleave (kernel_atb.h); aclnnRopeWithSinCosCache for partial rotary (kernel_sincos_cache.h) |

Bundles the RotaryEmbedding API alignment: `query_out` / `key_out` are now `std::optional<Tensor>`; when omitted, the op writes in place on `query` / `key` (matching vLLM `RotaryEmbedding.forward(positions, query, key)`).

New `src/base/<op>.h`: apply_rotary_pos_emb, silu_and_mul. Modified: add_rms_norm (constructor signature alignment), rotary_embedding (optional query_out/key_out).
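To make the in-place convention concrete, here is a minimal sketch of the call shape it implies. The free function, the `apply_rope` helper, and the `Tensor` stub are assumptions for illustration, not code from this commit; the real `Tensor` is a shared handle passed by value, which the sketch mirrors:

```cpp
#include <optional>

struct Tensor { /* shared-handle type, mirroring the by-value usage in the diff */ };

// Hypothetical stand-in for the real rope kernel dispatch.
void apply_rope(const Tensor&, const Tensor&, const Tensor&,
                Tensor&, Tensor&) {}

// Omitted query_out / key_out mean "write in place on query / key",
// matching vLLM RotaryEmbedding.forward(positions, query, key).
void rotary_embedding(const Tensor& positions, Tensor query, Tensor key,
                      std::optional<Tensor> query_out = std::nullopt,
                      std::optional<Tensor> key_out = std::nullopt) {
  Tensor q_dst = query_out.value_or(query);  // fall back to in-place target
  Tensor k_dst = key_out.value_or(key);
  apply_rope(positions, query, key, q_dst, k_dst);
}
```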
1 parent 13cf84a commit e38d08b

22 files changed

Lines changed: 3683 additions & 41 deletions

src/ascend/add_rms_norm/kernel.h

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_
#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_

#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_add.h"
#include "aclnn_rms_norm.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add_rms_norm.h"
#include "operator.h"

namespace infini::ops {

// Decomposed implementation: aclnnAdd + aclnnRmsNorm.
//
// The fused aclnnAddRmsNorm API has ~200 us host-side launch overhead that
// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls
// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible
// NPU-side impact for inference tensor sizes.
template <>
class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
 public:
  Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps,
           Tensor y_out, Tensor x_out)
      : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out),
        x1_cache_(x1),
        x2_cache_(x2),
        gamma_cache_(gamma),
        y_out_cache_(y_out),
        x_out_cache_(x_out) {
    // Alpha scalar for aclnnAdd (x_out = x1 + 1.0 * x2).
    alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT);

    // aclnnRmsNorm writes rstd as a required side output.
    // Size computed here; buffer obtained from pool in `operator()`.
    rstd_shape_ = {static_cast<int64_t>(batch_size_),
                   static_cast<int64_t>(nhead_)};
    rstd_size_ = batch_size_ * nhead_ * sizeof(float);
  }

  ~Operator() {
    if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`.
    x1_cache_.release();
    x2_cache_.release();
    gamma_cache_.release();
    y_out_cache_.release();
    x_out_cache_.release();

    // `rstd_tensor_` leaks with `norm_exec_` at shutdown (see `64c367c`).
    if (alpha_) aclDestroyScalar(alpha_);
  }

  void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
                  float eps, Tensor y_out, Tensor x_out) const override {
    auto t_x1 = x1_cache_.get(const_cast<void*>(x1.data()));
    auto t_x2 = x2_cache_.get(const_cast<void*>(x2.data()));
    auto t_gamma = gamma_cache_.get(const_cast<void*>(gamma.data()));
    auto t_y_out = y_out_cache_.get(y_out.data());
    auto t_x_out = x_out_cache_.get(x_out.data());
    auto stream = static_cast<aclrtStream>(stream_);

    // Step 1: x_out = x1 + x2.
    if (!add_exec_) {
      aclnnAddGetWorkspaceSize(t_x1, t_x2, alpha_, t_x_out, &add_ws_,
                               &add_exec_);
      aclSetAclOpExecutorRepeatable(add_exec_);
    } else {
      aclSetInputTensorAddr(add_exec_, 0, t_x1, const_cast<void*>(x1.data()));
      aclSetInputTensorAddr(add_exec_, 1, t_x2, const_cast<void*>(x2.data()));
      aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data());
    }
    auto& add_arena = ascend::GetWorkspacePool().Ensure(stream, add_ws_);
    aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream);

    // Obtain shared rstd buffer from pool.
    auto& rstd_arena =
        ascend::GetWorkspacePool().Ensure(stream, rstd_size_, "temp");

    // Lazily create rstd tensor descriptor on first call.
    if (!rstd_tensor_) {
      rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT,
                                     /*strides=*/nullptr, 0, ACL_FORMAT_ND,
                                     rstd_shape_.data(), 2, rstd_arena.buf);
    } else {
      aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
    }

    // Step 2: y_out = rms_norm(x_out, gamma, eps).
    if (!norm_exec_) {
      aclnnRmsNormGetWorkspaceSize(t_x_out, t_gamma, eps, t_y_out, rstd_tensor_,
                                   &norm_ws_, &norm_exec_);
      aclSetAclOpExecutorRepeatable(norm_exec_);
    } else {
      aclSetInputTensorAddr(norm_exec_, 0, t_x_out, x_out.data());
      aclSetInputTensorAddr(norm_exec_, 1, t_gamma,
                            const_cast<void*>(gamma.data()));
      aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data());
      aclSetOutputTensorAddr(norm_exec_, 1, rstd_tensor_, rstd_arena.buf);
    }
    auto& norm_arena = ascend::GetWorkspacePool().Ensure(stream, norm_ws_);
    aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream);
  }

 private:
  mutable ascend::AclTensorCache x1_cache_;
  mutable ascend::AclTensorCache x2_cache_;
  mutable ascend::AclTensorCache gamma_cache_;
  mutable ascend::AclTensorCache y_out_cache_;
  mutable ascend::AclTensorCache x_out_cache_;

  float alpha_storage_ = 1.0f;
  aclScalar* alpha_ = nullptr;

  std::vector<int64_t> rstd_shape_;
  uint64_t rstd_size_ = 0;
  mutable aclTensor* rstd_tensor_ = nullptr;

  mutable aclOpExecutor* add_exec_ = nullptr;
  mutable uint64_t add_ws_ = 0;
  mutable aclOpExecutor* norm_exec_ = nullptr;
  mutable uint64_t norm_ws_ = 0;
};

}  // namespace infini::ops

#endif
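The dispatch-overhead claim in the header comment (~224 us fused vs ~56 us decomposed) refers to host-side enqueue time only. A hedged sketch of how one might measure that quantity; `HostDispatchMicros` is a name invented here, and it assumes `operator()` returns once work is queued on the stream:

```cpp
#include <chrono>
#include <utility>

// Measures CPU-side enqueue cost only: no stream synchronization between the
// timestamps, so NPU execution time is deliberately excluded.
template <typename Op, typename... Args>
double HostDispatchMicros(const Op& op, Args&&... args) {
  auto t0 = std::chrono::steady_clock::now();
  op(std::forward<Args>(args)...);  // e.g. enqueues aclnnAdd + aclnnRmsNorm
  auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::micro>(t1 - t0).count();
}
```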
src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_
#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_

#ifdef INFINI_HAS_CUSTOM_KERNELS

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnnop/aclnn_cast.h"
#include "ascend/common.h"
#include "ascend/workspace_pool_.h"
#include "base/add_rms_norm.h"
#include "operator.h"

// Forward-declare the generated AscendC kernel launch function.
// This symbol is provided by the `no_workspace_kernel` static library
// built from `ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp`
// via `ascendc_library()`.
extern "C" uint32_t aclrtlaunch_add_rms_norm(
    uint32_t blockDim, void* stream, void* x1, void* x2, void* weight, void* y,
    void* x_out, int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
    int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps,
    int64_t dtypeSize);

namespace infini::ops {

// Custom AscendC fused AddRmsNorm kernel (implementation index 2).
//
// A single-kernel implementation that computes x_out = x1 + x2 followed by
// y = rms_norm(x_out, gamma, eps) in one launch, avoiding the decomposed
// aclnnAdd + aclnnRmsNorm calls (index 0) or the fused aclnnAddRmsNorm call
// (index 1). Migrated from the custom RmsNorm kernel (index 1 of RmsNorm).
//
// Select via `implementation_index=2` in Python:
//   infini.ops.add_rms_norm(x1, x2, gamma, eps, y_out, x_out,
//                           implementation_index=2, stream=s)
//
// Requirements:
//   - Input last dimension must be 32-byte aligned (divisible by 16 for fp16
//     or 8 for fp32). All standard LLM hidden dimensions satisfy this.
//   - Weight must have the same dtype as input.
//   - The custom kernel binary must be linked (`BUILD_CUSTOM_KERNEL=ON`).
template <>
class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
 public:
  Operator(const Tensor x1, const Tensor x2, const Tensor gamma, float eps,
           Tensor y_out, Tensor x_out)
      : AddRmsNorm(x1, x2, gamma, eps, y_out, x_out) {
    // Dtype size in bytes.
    dtype_size_ = (x1.dtype() == DataType::kFloat16) ? 2 : 4;

    // Alignment check (32-byte boundary).
    int64_t align_elems = 32 / dtype_size_;
    dim_length_align_ =
        ((static_cast<int64_t>(dim_) + align_elems - 1) / align_elems) *
        align_elems;
    assert(static_cast<int64_t>(dim_) == dim_length_align_ &&
           "Custom AddRmsNorm kernel requires 32-byte aligned last dimension");

    total_rows_ =
        static_cast<int64_t>(batch_size_) * static_cast<int64_t>(nhead_);

    // For fp16 input, weight needs fp32 conversion because the custom
    // kernel always reads weight as fp32.
    needs_weight_cast_ = (dtype_size_ == 2);

    if (needs_weight_cast_) {
      // Allocate persistent fp32 weight buffer on device.
      size_t fp32_bytes = static_cast<size_t>(dim_) * sizeof(float);
      aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY);

      // `AclTensorCache` for the cast source (fp16 weight descriptor).
      weight_src_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
                                                 ACL_FLOAT16, nullptr);

      // `AclTensorCache` for the cast destination (fp32 weight buffer).
      weight_dst_cache_ = ascend::AclTensorCache({static_cast<int64_t>(dim_)},
                                                 ACL_FLOAT, weight_fp32_data_);
    }
  }

  ~Operator() {
    if (!ascend::IsAclRuntimeAlive()) return;

    // Null cached descriptors — see `AclTensorCache::release()`.
    weight_src_cache_.release();
    weight_dst_cache_.release();

    if (weight_fp32_data_) aclrtFree(weight_fp32_data_);
  }

  void operator()(const Tensor x1, const Tensor x2, const Tensor gamma,
                  float eps, Tensor y_out, Tensor x_out) const override {
    auto stream = static_cast<aclrtStream>(stream_);

    // Determine fp32 weight pointer.
    void* weight_fp32;

    if (needs_weight_cast_) {
      // Only re-cast when the weight data pointer changes. Model weights
      // are fixed after loading, so this typically runs once on the first
      // call and is skipped on all subsequent calls.
      const void* cur_weight = gamma.data();

      if (cur_weight != last_weight_ptr_) {
        auto t_src = weight_src_cache_.get(const_cast<void*>(cur_weight));
        auto t_dst = weight_dst_cache_.get(weight_fp32_data_);

        if (!cast_exec_) {
          aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_,
                                    &cast_exec_);
          aclSetAclOpExecutorRepeatable(cast_exec_);
        } else {
          aclSetInputTensorAddr(cast_exec_, 0, t_src,
                                const_cast<void*>(cur_weight));
          aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_);
        }

        auto& arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_);
        aclnnCast(arena.buf, cast_ws_, cast_exec_, stream);
        last_weight_ptr_ = cur_weight;
      }

      weight_fp32 = weight_fp32_data_;
    } else {
      // Input is fp32 — weight is already fp32.
      weight_fp32 = const_cast<void*>(gamma.data());
    }

    // Block-level tiling: distribute rows across cores.
    static constexpr int64_t kMaxBlockDim = 40;
    int64_t used_cores = std::min(total_rows_, kMaxBlockDim);
    int64_t former_length = (total_rows_ + used_cores - 1) / used_cores;
    int64_t tail_length = former_length - 1;
    int64_t former_num = total_rows_ - tail_length * used_cores;
    uint32_t block_dim = static_cast<uint32_t>(used_cores);

    // Launch custom AscendC kernel.
    aclrtlaunch_add_rms_norm(
        block_dim, stream, const_cast<void*>(x1.data()),
        const_cast<void*>(x2.data()), weight_fp32, y_out.data(), x_out.data(),
        total_rows_, static_cast<int64_t>(dim_), dim_length_align_, former_num,
        former_length, tail_length, eps, dtype_size_);
  }

 private:
  int64_t dtype_size_;
  int64_t dim_length_align_;
  int64_t total_rows_;

  bool needs_weight_cast_;
  void* weight_fp32_data_ = nullptr;
  mutable ascend::AclTensorCache weight_src_cache_;
  mutable ascend::AclTensorCache weight_dst_cache_;
  mutable const void* last_weight_ptr_ = nullptr;

  mutable aclOpExecutor* cast_exec_ = nullptr;
  mutable uint64_t cast_ws_ = 0;
};

}  // namespace infini::ops

#endif  // INFINI_HAS_CUSTOM_KERNELS
#endif  // INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_
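The block-level tiling in `operator()` gives `former_num` cores one extra row each so that all `total_rows_` rows are covered by at most 40 cores. A standalone check of the same arithmetic; the inputs (98 rows, a 40-core cap) are illustrative, not values from this diff:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>

// Same tiling arithmetic as in operator(): `former_num` cores each take
// `former_length` rows, the remaining cores take `tail_length` rows.
int main() {
  const int64_t kMaxBlockDim = 40;  // AI-core cap used by the kernel
  int64_t total_rows = 98;          // e.g. batch_size_ * nhead_
  int64_t used = std::min(total_rows, kMaxBlockDim);
  int64_t former_length = (total_rows + used - 1) / used;  // ceil division
  int64_t tail_length = former_length - 1;
  int64_t former_num = total_rows - tail_length * used;
  // Every row is assigned exactly once.
  assert(former_num * former_length + (used - former_num) * tail_length ==
         total_rows);
  std::cout << former_num << " cores x " << former_length << " rows, "
            << (used - former_num) << " cores x " << tail_length << " rows\n";
}
```

With 98 rows, 18 cores process 3 rows and 22 cores process 2 rows, which sums back to 98.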

0 commit comments
